From 4ccfee47cc876797640ea529b84090b5f0377ccb Mon Sep 17 00:00:00 2001 From: HongYu <20734616+james77777778@users.noreply.github.com> Date: Tue, 23 Jan 2024 00:35:22 +0800 Subject: [PATCH] Add `export_tflite` and `export_onnx` and support `channels_first` (#28) --- .github/workflows/actions.yml | 1 + .gitignore | 6 +- conftest.py | 10 ++ kimm/__init__.py | 1 + kimm/blocks/base_block.py | 19 +++- kimm/blocks/depthwise_separation_block.py | 4 +- kimm/blocks/inverted_residual_block.py | 4 +- kimm/blocks/transformer_block.py | 8 +- kimm/export/__init__.py | 2 + kimm/export/export_onnx.py | 133 ++++++++++++++++++++++ kimm/export/export_onnx_test.py | 30 +++++ kimm/export/export_tflite.py | 72 ++++++++++++ kimm/export/export_tflite_test.py | 66 +++++++++++ kimm/layers/attention_test.py | 2 + kimm/layers/layer_scale.py | 27 ++++- kimm/layers/layer_scale_test.py | 4 +- kimm/layers/position_embedding_test.py | 3 + kimm/models/base_model.py | 9 +- kimm/models/convmixer.py | 18 ++- kimm/models/convnext.py | 32 ++++-- kimm/models/densenet.py | 29 +++-- kimm/models/efficientnet.py | 10 +- kimm/models/ghostnet.py | 37 ++++-- kimm/models/inception_v3.py | 51 +++++++-- kimm/models/mobilevit.py | 17 ++- kimm/models/models_test.py | 55 ++++++++- kimm/models/regnet.py | 4 +- kimm/models/resnet.py | 7 +- kimm/models/vgg.py | 6 + kimm/models/vision_transformer.py | 9 +- kimm/models/xception.py | 19 +++- pyproject.toml | 7 +- 32 files changed, 623 insertions(+), 79 deletions(-) create mode 100644 kimm/export/__init__.py create mode 100644 kimm/export/export_onnx.py create mode 100644 kimm/export/export_onnx_test.py create mode 100644 kimm/export/export_tflite.py create mode 100644 kimm/export/export_tflite_test.py diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml index c00266f..4dd859f 100644 --- a/.github/workflows/actions.yml +++ b/.github/workflows/actions.yml @@ -22,6 +22,7 @@ jobs: runs-on: ubuntu-latest env: PYTHON: ${{ matrix.python-version }} + KERAS_BACKEND: ${{ matrix.backend }} steps: - uses: actions/checkout@v4 - name: Set up Python diff --git a/.gitignore b/.gitignore index 0ed7f34..aa7fe58 100644 --- a/.gitignore +++ b/.gitignore @@ -161,4 +161,8 @@ cython_debug/ # Keras *.keras -exported \ No newline at end of file +exported + +# Exported model +*.tflite +*.onnx diff --git a/conftest.py b/conftest.py index fce612b..3ce6930 100644 --- a/conftest.py +++ b/conftest.py @@ -1,6 +1,7 @@ import os import pytest +from keras import backend def pytest_addoption(parser): @@ -27,6 +28,10 @@ def pytest_configure(config): config.addinivalue_line( "markers", "serialization: mark test as a serialization test" ) + config.addinivalue_line( + "markers", + "requires_trainable_backend: mark test for trainable backend only", + ) def pytest_collection_modifyitems(config, items): @@ -35,6 +40,11 @@ def pytest_collection_modifyitems(config, items): not run_serialization_tests, reason="need --run_serialization option to run", ) + requires_trainable_backend = pytest.mark.skipif( + backend.backend() == "numpy", reason="require trainable backend" + ) for item in items: + if "requires_trainable_backend" in item.keywords: + item.add_marker(requires_trainable_backend) if "serialization" in item.name: item.add_marker(skip_serialization) diff --git a/kimm/__init__.py b/kimm/__init__.py index 57a48e1..01dbccf 100644 --- a/kimm/__init__.py +++ b/kimm/__init__.py @@ -1,3 +1,4 @@ +from kimm import export from kimm import models # force to add models to the registry from kimm.utils.model_registry import list_models diff --git a/kimm/blocks/base_block.py b/kimm/blocks/base_block.py index cc93599..430745f 100644 --- a/kimm/blocks/base_block.py +++ b/kimm/blocks/base_block.py @@ -1,5 +1,6 @@ import typing +from keras import backend from keras import layers from kimm.utils import make_divisible @@ -43,7 +44,9 @@ def apply_conv2d_block( ) if isinstance(kernel_size, int): kernel_size = [kernel_size, kernel_size] - input_channels = inputs.shape[-1] + + channels_axis = -1 if backend.image_data_format() == "channels_last" else -3 + input_channels = inputs.shape[channels_axis] has_skip = add_skip and strides == 1 and input_channels == filters x = inputs @@ -74,7 +77,10 @@ def apply_conv2d_block( name=f"{name}_dwconv2d", )(x) x = layers.BatchNormalization( - name=f"{name}_bn", momentum=bn_momentum, epsilon=bn_epsilon + axis=channels_axis, + name=f"{name}_bn", + momentum=bn_momentum, + epsilon=bn_epsilon, )(x) x = apply_activation(x, activation, name=name) if has_skip: @@ -91,7 +97,8 @@ def apply_se_block( se_input_channels: typing.Optional[int] = None, name: str = "se_block", ): - input_channels = inputs.shape[-1] + channels_axis = -1 if backend.image_data_format() == "channels_last" else -3 + input_channels = inputs.shape[channels_axis] if se_input_channels is None: se_input_channels = input_channels if make_divisible_number is None: @@ -102,7 +109,11 @@ def apply_se_block( ) x = inputs - x = layers.GlobalAveragePooling2D(keepdims=True, name=f"{name}_mean")(x) + x = layers.GlobalAveragePooling2D( + data_format=backend.image_data_format(), + keepdims=True, + name=f"{name}_mean", + )(x) x = layers.Conv2D( se_channels, 1, use_bias=True, name=f"{name}_conv_reduce" )(x) diff --git a/kimm/blocks/depthwise_separation_block.py b/kimm/blocks/depthwise_separation_block.py index 24b9b47..a70ecee 100644 --- a/kimm/blocks/depthwise_separation_block.py +++ b/kimm/blocks/depthwise_separation_block.py @@ -1,5 +1,6 @@ import typing +from keras import backend from keras import layers from kimm.blocks.base_block import apply_conv2d_block @@ -23,7 +24,8 @@ def apply_depthwise_separation_block( padding: typing.Optional[typing.Literal["same", "valid"]] = None, name: str = "depthwise_separation_block", ): - input_channels = inputs.shape[-1] + channels_axis = -1 if backend.image_data_format() == "channels_last" else -3 + input_channels = inputs.shape[channels_axis] has_skip = skip and (strides == 1 and input_channels == output_channels) x = inputs diff --git a/kimm/blocks/inverted_residual_block.py b/kimm/blocks/inverted_residual_block.py index 637099c..b5dc95c 100644 --- a/kimm/blocks/inverted_residual_block.py +++ b/kimm/blocks/inverted_residual_block.py @@ -1,5 +1,6 @@ import typing +from keras import backend from keras import layers from kimm.blocks.base_block import apply_conv2d_block @@ -25,7 +26,8 @@ def apply_inverted_residual_block( padding: typing.Optional[typing.Literal["same", "valid"]] = None, name: str = "inverted_residual_block", ): - input_channels = inputs.shape[-1] + channels_axis = -1 if backend.image_data_format() == "channels_last" else -3 + input_channels = inputs.shape[channels_axis] hidden_channels = make_divisible(input_channels * expansion_ratio) has_skip = strides == 1 and input_channels == output_channels diff --git a/kimm/blocks/transformer_block.py b/kimm/blocks/transformer_block.py index 34c1699..42bb60b 100644 --- a/kimm/blocks/transformer_block.py +++ b/kimm/blocks/transformer_block.py @@ -1,5 +1,6 @@ import typing +from keras import backend from keras import layers from kimm import layers as kimm_layers @@ -13,9 +14,13 @@ def apply_mlp_block( use_bias: bool = True, dropout_rate: float = 0.0, use_conv_mlp: bool = False, + data_format: typing.Optional[str] = None, name: str = "mlp_block", ): - input_dim = inputs.shape[-1] + if data_format is None: + data_format = backend.image_data_format() + dim_axis = -1 if data_format == "channels_last" else 1 + input_dim = inputs.shape[dim_axis] output_dim = output_dim or input_dim x = inputs @@ -71,6 +76,7 @@ def apply_transformer_block( int(dim * mlp_ratio), activation=activation, dropout_rate=projection_dropout_rate, + data_format="channels_last", # TODO: let backend decides name=f"{name}_mlp", ) x = layers.Add()([residual_2, x]) diff --git a/kimm/export/__init__.py b/kimm/export/__init__.py new file mode 100644 index 0000000..a3000a7 --- /dev/null +++ b/kimm/export/__init__.py @@ -0,0 +1,2 @@ +from kimm.export.export_onnx import export_onnx +from kimm.export.export_tflite import export_tflite diff --git a/kimm/export/export_onnx.py b/kimm/export/export_onnx.py new file mode 100644 index 0000000..c0d4833 --- /dev/null +++ b/kimm/export/export_onnx.py @@ -0,0 +1,133 @@ +import pathlib +import tempfile +import typing + +from keras import backend +from keras import layers +from keras import models +from keras import ops + +from kimm.models import BaseModel + + +def _export_onnx_tf( + model: BaseModel, + inputs_as_nchw, + export_path: typing.Union[str, pathlib.Path], +): + try: + import tf2onnx + import tf2onnx.tf_loader + except ModuleNotFoundError: + raise ModuleNotFoundError( + "Failed to import 'tf2onnx'. Please install it by the following " + "instruction:\n'pip install tf2onnx'" + ) + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = pathlib.Path(temp_dir, "temp_saved_model") + model.export(temp_path) + + ( + graph_def, + inputs, + outputs, + tensors_to_rename, + ) = tf2onnx.tf_loader.from_saved_model( + temp_path, + None, + None, + return_tensors_to_rename=True, + ) + + tf2onnx.convert.from_graph_def( + graph_def, + input_names=inputs, + output_names=outputs, + output_path=export_path, + inputs_as_nchw=inputs_as_nchw, + tensors_to_rename=tensors_to_rename, + ) + + +def _export_onnx_torch( + model: BaseModel, + input_shape: typing.Union[int, int, int], + export_path: typing.Union[str, pathlib.Path], +): + try: + import torch + except ModuleNotFoundError: + raise ModuleNotFoundError( + "Failed to import 'torch'. Please install it before calling" + "`export_onnx` using torch backend" + ) + full_input_shape = [1] + list(input_shape) + dummy_inputs = ops.ones(full_input_shape) + scripted_model = torch.jit.trace(model, dummy_inputs).eval() + torch.onnx.export(scripted_model, dummy_inputs, export_path) + + +def export_onnx( + model: BaseModel, + input_shape: typing.Union[int, typing.Sequence[int]], + export_path: typing.Union[str, pathlib.Path], + batch_size: int = 1, + use_nchw: bool = True, +): + if backend.backend() not in ("tensorflow", "torch"): + raise ValueError( + "Currently, `export_onnx` only supports tensorflow and torch " + "backend" + ) + try: + import onnx + import onnxoptimizer + import onnxsim + except ModuleNotFoundError: + raise ModuleNotFoundError( + "Failed to import 'onnx', 'onnxsim' or 'onnxoptimizer'. Please " + "install them by the following instruction:\n" + "'pip install onnx onnxsim onnxoptimizer'" + ) + + if isinstance(input_shape, int): + input_shape = [input_shape, input_shape, 3] + elif len(input_shape) == 2: + input_shape = [input_shape[0], input_shape[1], 3] + elif len(input_shape) == 3: + input_shape = input_shape + if use_nchw: + if backend.backend() == "torch": + raise ValueError( + "Currently, torch backend doesn't support `use_nchw=True`. " + "You can use tensorflow backend to overcome this issue or " + "set `use_nchw=False`. " + "Note that there might be a significant performance " + "degradation when using torch backend to export onnx due to " + "the pre- and post-transpose of the Conv2D." + ) + elif backend.backend() == "tensorflow": + inputs_as_nchw = ["inputs"] + else: + inputs_as_nchw = None + else: + inputs_as_nchw = None + + # Fix input shape + inputs = layers.Input( + shape=input_shape, batch_size=batch_size, name="inputs" + ) + outputs = model(inputs, training=False) + model = models.Model(inputs, outputs) + + if backend.backend() == "tensorflow": + _export_onnx_tf(model, inputs_as_nchw, export_path) + elif backend.backend() == "torch": + _export_onnx_torch(model, input_shape, export_path) + + # Further optimization + model = onnx.load(export_path) + model_simp, _ = onnxsim.simplify(model) + model_simp = onnxoptimizer.optimize(model_simp) + onnx.save(model_simp, export_path) diff --git a/kimm/export/export_onnx_test.py b/kimm/export/export_onnx_test.py new file mode 100644 index 0000000..81e3300 --- /dev/null +++ b/kimm/export/export_onnx_test.py @@ -0,0 +1,30 @@ +import pytest +from absl.testing import parameterized +from keras import backend +from keras.src import testing + +from kimm import export +from kimm import models + + +class ExportOnnxTest(testing.TestCase, parameterized.TestCase): + def get_model(self): + input_shape = [224, 224, 3] + model = models.MobileNet050V3Small(include_preprocessing=False) + return input_shape, model + + @pytest.mark.skipif( + backend.backend() != "tensorflow", # TODO: test torch + reason="Requires tensorflow or torch backend.", + ) + def test_export_onnx_use(self): + input_shape, model = self.get_model() + + temp_dir = self.get_temp_dir() + + if backend.backend() == "tensorflow": + export.export_onnx(model, input_shape, f"{temp_dir}/model.onnx") + elif backend.backend() == "torch": + export.export_onnx( + model, input_shape, f"{temp_dir}/model.onnx", use_nchw=False + ) diff --git a/kimm/export/export_tflite.py b/kimm/export/export_tflite.py new file mode 100644 index 0000000..a04351e --- /dev/null +++ b/kimm/export/export_tflite.py @@ -0,0 +1,72 @@ +import pathlib +import tempfile +import typing + +from keras import backend +from keras import layers +from keras import models +from keras.src.utils.module_utils import tensorflow as tf + +from kimm.models import BaseModel + + +def export_tflite( + model: BaseModel, + input_shape: typing.Union[int, typing.Sequence[int]], + export_path: typing.Union[str, pathlib.Path], + export_dtype: typing.Literal["float32", "float16", "int8"] = "float32", + representative_dataset: typing.Optional[typing.Iterator] = None, + batch_size: int = 1, +): + if backend.backend() != "tensorflow": + raise ValueError( + "Currently, `export_tflite` only supports tensorflow backend" + ) + if export_dtype not in ("float32", "float16", "int8"): + raise ValueError( + "`export_dtype` must be one of ('float32', 'float16', 'int8'). " + f"Received: export_dtype={export_dtype}" + ) + if export_dtype == "int8" and representative_dataset is None: + raise ValueError( + "For full integer quantization, a `representative_dataset` should " + "be specified." + ) + if isinstance(input_shape, int): + input_shape = [input_shape, input_shape, 3] + elif len(input_shape) == 2: + input_shape = [input_shape[0], input_shape[1], 3] + elif len(input_shape) == 3: + input_shape = input_shape + + # Fix input shape + inputs = layers.Input(shape=input_shape, batch_size=batch_size) + outputs = model(inputs, training=False) + model = models.Model(inputs, outputs) + + # Construct TFLiteConverter + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = pathlib.Path(temp_dir, "temp_saved_model") + model.export(temp_path) + converter = tf.lite.TFLiteConverter.from_saved_model(str(temp_path)) + + # Configure converter + if export_dtype != "float32": + converter.optimizations = [tf.lite.Optimize.DEFAULT] + if export_dtype == "int8": + converter.target_spec.supported_ops = [ + tf.lite.OpsSet.TFLITE_BUILTINS_INT8 + ] + converter.inference_input_type = tf.int8 + converter.inference_output_type = tf.int8 + elif export_dtype == "float16": + converter.target_spec.supported_types = [tf.float16] + if representative_dataset is not None: + converter.representative_dataset = representative_dataset + + # Convert + tflite_model = converter.convert() + + # Export + with open(export_path, "wb") as f: + f.write(tflite_model) diff --git a/kimm/export/export_tflite_test.py b/kimm/export/export_tflite_test.py new file mode 100644 index 0000000..5604fe4 --- /dev/null +++ b/kimm/export/export_tflite_test.py @@ -0,0 +1,66 @@ +import pytest +from absl.testing import parameterized +from keras import backend +from keras import ops +from keras import random +from keras.src import testing + +from kimm import export +from kimm import models + + +class ExportTFLiteTest(testing.TestCase, parameterized.TestCase): + def get_model_and_representative_dataset(self): + input_shape = [224, 224, 3] + model = models.MobileNet050V3Small(include_preprocessing=False) + + def representative_dataset(): + for _ in range(10): + yield [ + ops.convert_to_numpy( + random.uniform([1, *input_shape], maxval=255.0) + ) + ] + + return input_shape, model, representative_dataset + + @pytest.mark.skipif( + backend.backend() != "tensorflow", reason="Requires tensorflow backend." + ) + def test_export_tflite_fp32(self): + (input_shape, model, _) = self.get_model_and_representative_dataset() + temp_dir = self.get_temp_dir() + + export.export_tflite( + model, input_shape, f"{temp_dir}/model_fp32.onnx", "float32" + ) + + @pytest.mark.skipif( + backend.backend() != "tensorflow", reason="Requires tensorflow backend." + ) + def test_export_tflite_fp16(self): + (input_shape, model, _) = self.get_model_and_representative_dataset() + temp_dir = self.get_temp_dir() + + export.export_tflite( + model, input_shape, f"{temp_dir}/model_fp16.tflite", "float16" + ) + + @pytest.mark.skipif( + backend.backend() != "tensorflow", reason="Requires tensorflow backend." + ) + def test_export_tflite_int8(self): + ( + input_shape, + model, + representative_dataset, + ) = self.get_model_and_representative_dataset() + temp_dir = self.get_temp_dir() + + export.export_tflite( + model, + input_shape, + f"{temp_dir}/model_int8.tflite", + "int8", + representative_dataset, + ) diff --git a/kimm/layers/attention_test.py b/kimm/layers/attention_test.py index 639f238..12a6992 100644 --- a/kimm/layers/attention_test.py +++ b/kimm/layers/attention_test.py @@ -1,3 +1,4 @@ +import pytest from absl.testing import parameterized from keras.src import testing @@ -5,6 +6,7 @@ class AttentionTest(testing.TestCase, parameterized.TestCase): + @pytest.mark.requires_trainable_backend def test_attention_basic(self): self.run_layer_test( Attention, diff --git a/kimm/layers/layer_scale.py b/kimm/layers/layer_scale.py index 0265f31..7ef61c7 100644 --- a/kimm/layers/layer_scale.py +++ b/kimm/layers/layer_scale.py @@ -9,30 +9,47 @@ class LayerScale(layers.Layer): def __init__( self, - hidden_size: int, + axis: int = -1, initializer: Initializer = initializers.Constant(1e-5), name: str = "layer_scale", **kwargs, ): super().__init__(**kwargs) - self.hidden_size = hidden_size + self.axis = axis self.initializer = initializer self.name = name def build(self, input_shape): + if isinstance(self.axis, list): + shape = tuple([input_shape[dim] for dim in self.axis]) + else: + shape = (input_shape[self.axis],) + self.axis = [self.axis] self.gamma = self.add_weight( - [self.hidden_size], initializer=self.initializer, name="gamma" + shape, initializer=self.initializer, name="gamma" ) self.built = True def call(self, inputs, training=None, mask=None): - return ops.multiply(inputs, self.gamma) + inputs = ops.cast(inputs, self.compute_dtype) + + # Broadcasting only necessary for norm when the axis is not just + # the last dimension + input_shape = inputs.shape + ndims = len(inputs.shape) + broadcast_shape = [1] * ndims + for dim in self.axis: + broadcast_shape[dim] = input_shape[dim] + gamma = ops.reshape(self.gamma, broadcast_shape) + gamma = ops.cast(gamma, self.compute_dtype) + + return ops.multiply(inputs, gamma) def get_config(self): config = super().get_config() config.update( { - "hidden_size": self.hidden_size, + "axis": self.axis, "initializer": initializers.serialize(self.initializer), "name": self.name, } diff --git a/kimm/layers/layer_scale_test.py b/kimm/layers/layer_scale_test.py index 1a3fd89..6344923 100644 --- a/kimm/layers/layer_scale_test.py +++ b/kimm/layers/layer_scale_test.py @@ -1,3 +1,4 @@ +import pytest from absl.testing import parameterized from keras.src import testing @@ -5,10 +6,11 @@ class LayerScaleTest(testing.TestCase, parameterized.TestCase): + @pytest.mark.requires_trainable_backend def test_layer_scale_basic(self): self.run_layer_test( LayerScale, - init_kwargs={"hidden_size": 10}, + init_kwargs={"axis": -1}, input_shape=(1, 10), expected_output_shape=(1, 10), expected_num_trainable_weights=1, diff --git a/kimm/layers/position_embedding_test.py b/kimm/layers/position_embedding_test.py index 9796e9f..f6fd8ff 100644 --- a/kimm/layers/position_embedding_test.py +++ b/kimm/layers/position_embedding_test.py @@ -1,3 +1,4 @@ +import pytest from absl.testing import parameterized from keras import layers from keras.src import testing @@ -6,6 +7,7 @@ class PositionEmbeddingTest(testing.TestCase, parameterized.TestCase): + @pytest.mark.requires_trainable_backend def test_position_embedding_basic(self): self.run_layer_test( PositionEmbedding, @@ -18,6 +20,7 @@ def test_position_embedding_basic(self): supports_masking=False, ) + @pytest.mark.requires_trainable_backend def test_position_embedding_invalid_input_shape(self): inputs = layers.Input([3]) with self.assertRaisesRegex( diff --git a/kimm/models/base_model.py b/kimm/models/base_model.py index 06baa88..3a3a0f7 100644 --- a/kimm/models/base_model.py +++ b/kimm/models/base_model.py @@ -109,7 +109,7 @@ def determine_input_tensor( input_shape, default_size=default_size, min_size=min_size, - data_format="channels_last", # always channels_last + data_format=backend.image_data_format(), require_flatten=require_flatten or static_shape, weights=None, ) @@ -130,11 +130,16 @@ def build_preprocessing( ): if self._include_preprocessing is False: return inputs + channels_axis = ( + -1 if backend.image_data_format() == "channels_last" else -3 + ) if mode == "imagenet": # [0, 255] to [0, 1] and apply ImageNet mean and variance x = layers.Rescaling(scale=1.0 / 255.0)(inputs) x = layers.Normalization( - mean=[0.485, 0.456, 0.406], variance=[0.229, 0.224, 0.225] + axis=channels_axis, + mean=[0.485, 0.456, 0.406], + variance=[0.229, 0.224, 0.225], )(x) elif mode == "0_1": # [0, 255] to [-1, 1] diff --git a/kimm/models/convmixer.py b/kimm/models/convmixer.py index d3d455a..6557b62 100644 --- a/kimm/models/convmixer.py +++ b/kimm/models/convmixer.py @@ -1,6 +1,7 @@ import typing import keras +from keras import backend from keras import layers from kimm.models.base_model import BaseModel @@ -10,6 +11,8 @@ def apply_convmixer_block( inputs, output_channels, kernel_size, activation, name="convmixer_block" ): + channels_axis = -1 if backend.image_data_format() == "channels_last" else -3 + x = inputs # Depthwise @@ -22,7 +25,7 @@ def apply_convmixer_block( name=f"{name}_0_fn_0_dwconv2d", )(x) x = layers.BatchNormalization( - momentum=0.9, epsilon=1e-5, name=f"{name}_0_fn_2" + axis=channels_axis, momentum=0.9, epsilon=1e-5, name=f"{name}_0_fn_2" )(x) x = layers.Add()([x, inputs]) @@ -35,9 +38,9 @@ def apply_convmixer_block( use_bias=True, name=f"{name}_1_conv2d", )(x) - x = layers.BatchNormalization(momentum=0.9, epsilon=1e-5, name=f"{name}_3")( - x - ) + x = layers.BatchNormalization( + axis=channels_axis, momentum=0.9, epsilon=1e-5, name=f"{name}_3" + )(x) return x @@ -53,9 +56,12 @@ def __init__( **kwargs, ): kwargs["weights_url"] = self.get_weights_url(kwargs["weights"]) - input_tensor = kwargs.pop("input_tensor", None) self.set_properties(kwargs) + channels_axis = ( + -1 if backend.image_data_format() == "channels_last" else -3 + ) + inputs = self.determine_input_tensor( input_tensor, self._input_shape, @@ -78,7 +84,7 @@ def __init__( name="stem_conv2d", )(x) x = layers.BatchNormalization( - momentum=0.9, epsilon=1e-5, name="stem_bn" + axis=channels_axis, momentum=0.9, epsilon=1e-5, name="stem_bn" )(x) features["STEM"] = x diff --git a/kimm/models/convnext.py b/kimm/models/convnext.py index f0edf67..82e3a47 100644 --- a/kimm/models/convnext.py +++ b/kimm/models/convnext.py @@ -1,6 +1,7 @@ import typing import keras +from keras import backend from keras import initializers from keras import layers @@ -21,7 +22,9 @@ def apply_convnext_block( use_grn=False, name="convnext_block", ): - input_channels = inputs.shape[-1] + channels_axis = -1 if backend.image_data_format() == "channels_last" else -3 + input_channels = inputs.shape[channels_axis] + hidden_channels = int(mlp_ratio * output_channels) x = inputs shortcut = inputs @@ -42,7 +45,9 @@ def apply_convnext_block( use_bias=True, name=f"{name}_conv_dw_dwconv2d", )(x) - x = layers.LayerNormalization(epsilon=1e-6, name=f"{name}_norm")(x) + x = layers.LayerNormalization( + axis=channels_axis, epsilon=1e-6, name=f"{name}_norm" + )(x) # MLP x = apply_mlp_block( @@ -57,7 +62,9 @@ def apply_convnext_block( # LayerScale x = kimm_layers.LayerScale( - output_channels, initializers.Constant(1e-6), name=f"{name}_layerscale" + axis=channels_axis, + initializer=initializers.Constant(1e-6), + name=f"{name}_layerscale", )(x) # Downsample @@ -85,14 +92,16 @@ def apply_convnext_stage( use_grn=False, name="convnext_stage", ): - input_channels = inputs.shape[-1] + channels_axis = -1 if backend.image_data_format() == "channels_last" else -3 + input_channels = inputs.shape[channels_axis] + x = inputs # Downsample if input_channels != output_channels or strides > 1: ds_ks = 2 if strides > 1 else 1 x = layers.LayerNormalization( - epsilon=1e-6, name=f"{name}_downsample_0" + axis=channels_axis, epsilon=1e-6, name=f"{name}_downsample_0" )(x) x = layers.Conv2D( output_channels, @@ -136,9 +145,12 @@ def __init__( **kwargs, ): kwargs["weights_url"] = self.get_weights_url(kwargs["weights"]) - input_tensor = kwargs.pop("input_tensor", None) self.set_properties(kwargs) + channels_axis = ( + -1 if backend.image_data_format() == "channels_last" else -3 + ) + inputs = self.determine_input_tensor( input_tensor, self._input_shape, @@ -159,7 +171,9 @@ def __init__( use_bias=True, name="stem_0_conv2d", )(x) - x = layers.LayerNormalization(epsilon=1e-6, name="stem_1")(x) + x = layers.LayerNormalization( + axis=channels_axis, epsilon=1e-6, name="stem_1" + )(x) features["STEM_S4"] = x # Blocks (4 stages) @@ -197,7 +211,9 @@ def __init__( def build_top(self, inputs, classes, classifier_activation, dropout_rate): x = inputs x = layers.GlobalAveragePooling2D(name="avg_pool")(inputs) - x = layers.LayerNormalization(epsilon=1e-6, name="head_norm")(x) + x = layers.LayerNormalization(axis=-1, epsilon=1e-6, name="head_norm")( + x + ) x = layers.Dropout(rate=dropout_rate, name="head_dropout")(x) x = layers.Dense( classes, activation=classifier_activation, name="classifier" diff --git a/kimm/models/densenet.py b/kimm/models/densenet.py index 3993cc3..8f84064 100644 --- a/kimm/models/densenet.py +++ b/kimm/models/densenet.py @@ -1,6 +1,7 @@ import typing import keras +from keras import backend from keras import layers from kimm.blocks import apply_conv2d_block @@ -11,9 +12,11 @@ def apply_dense_layer( inputs, growth_rate, expansion_ratio=4.0, name="dense_layer" ): + channels_axis = -1 if backend.image_data_format() == "channels_last" else -3 + x = inputs x = layers.BatchNormalization( - momentum=0.9, epsilon=1e-5, name=f"{name}_norm1" + axis=channels_axis, momentum=0.9, epsilon=1e-5, name=f"{name}_norm1" )(x) x = layers.ReLU()(x) x = apply_conv2d_block( @@ -33,11 +36,12 @@ def apply_dense_layer( def apply_dense_block( inputs, num_layers, growth_rate, expansion_ratio=4.0, name="dense_block" ): - x = inputs + channels_axis = -1 if backend.image_data_format() == "channels_last" else -3 + x = inputs features = [x] for i in range(num_layers): - new_features = layers.Concatenate()(features) + new_features = layers.Concatenate(axis=channels_axis)(features) new_features = apply_dense_layer( new_features, growth_rate, @@ -45,22 +49,25 @@ def apply_dense_block( name=f"{name}_denselayer{i + 1}", ) features.append(new_features) - x = layers.Concatenate()(features) + x = layers.Concatenate(axis=channels_axis)(features) return x def apply_dense_transition_block( inputs, output_channels, name="dense_transition_block" ): + channels_axis = -1 if backend.image_data_format() == "channels_last" else -3 x = inputs x = layers.BatchNormalization( - momentum=0.9, epsilon=1e-5, name=f"{name}_norm" + axis=channels_axis, momentum=0.9, epsilon=1e-5, name=f"{name}_norm" )(x) x = layers.ReLU()(x) x = layers.Conv2D( output_channels, 1, 1, "same", use_bias=False, name=f"{name}_conv" )(x) - x = layers.AveragePooling2D(2, 2, name=f"{name}_pool")(x) + x = layers.AveragePooling2D( + 2, 2, data_format=backend.image_data_format(), name=f"{name}_pool" + )(x) return x @@ -78,9 +85,12 @@ def __init__( **kwargs, ): kwargs["weights_url"] = self.get_weights_url(kwargs["weights"]) - input_tensor = kwargs.pop("input_tensor", None) self.set_properties(kwargs) + channels_axis = ( + -1 if backend.image_data_format() == "channels_last" else -3 + ) + inputs = self.determine_input_tensor( input_tensor, self._input_shape, @@ -127,7 +137,10 @@ def __init__( # Final batch norm x = layers.BatchNormalization( - momentum=0.9, epsilon=1e-5, name="features_norm5" + axis=channels_axis, + momentum=0.9, + epsilon=1e-5, + name="features_norm5", )(x) x = layers.ReLU()(x) diff --git a/kimm/models/efficientnet.py b/kimm/models/efficientnet.py index ab6f0b4..6776fe8 100644 --- a/kimm/models/efficientnet.py +++ b/kimm/models/efficientnet.py @@ -2,6 +2,7 @@ import typing import keras +from keras import backend from keras import layers from kimm.blocks import apply_conv2d_block @@ -91,7 +92,8 @@ def apply_edge_residual_block( padding=None, name="edge_residual_block", ): - input_channels = inputs.shape[-1] + channels_axis = -1 if backend.image_data_format() == "channels_last" else -3 + input_channels = inputs.shape[channels_axis] hidden_channels = make_divisible(input_channels * expansion_ratio) has_skip = strides == 1 and input_channels == output_channels @@ -187,6 +189,10 @@ def __init__( input_tensor = kwargs.pop("input_tensor", None) self.set_properties(kwargs) + channels_axis = ( + -1 if backend.image_data_format() == "channels_last" else -3 + ) + inputs = self.determine_input_tensor( input_tensor, self._input_shape, @@ -241,7 +247,7 @@ def __init__( x, c, k, 1, s, se, se_activation=activation, **_kwargs ) elif block_type == "ir": - se_c = x.shape[-1] + se_c = x.shape[channels_axis] x = apply_inverted_residual_block( x, c, k, 1, 1, s, e, se, se_channels=se_c, **_kwargs ) diff --git a/kimm/models/ghostnet.py b/kimm/models/ghostnet.py index b932829..486701c 100644 --- a/kimm/models/ghostnet.py +++ b/kimm/models/ghostnet.py @@ -1,6 +1,7 @@ import typing import keras +from keras import backend from keras import layers from keras import ops @@ -64,6 +65,7 @@ def apply_ghost_block( activation="relu", name="ghost_block", ): + channels_axis = -1 if backend.image_data_format() == "channels_last" else -3 hidden_channels_1 = int(ops.ceil(output_channels / expand_ratio)) hidden_channels_2 = int(hidden_channels_1 * (expand_ratio - 1.0)) @@ -85,8 +87,11 @@ def apply_ghost_block( use_depthwise=True, name=f"{name}_cheap_operation", ) - out = layers.Concatenate(name=f"{name}")([x1, x2]) - return out[..., :output_channels] + out = layers.Concatenate(axis=channels_axis, name=f"{name}")([x1, x2]) + if channels_axis == -1: + return out[..., :output_channels] + else: + return out[:, :output_channels, ...] def apply_ghost_block_v2( @@ -99,6 +104,12 @@ def apply_ghost_block_v2( activation="relu", name="ghost_block_v2", ): + channels_axis = -1 if backend.image_data_format() == "channels_last" else -3 + if backend.image_data_format() == "channels_last": + output_axis = (-3, -2) + else: + output_axis = (-2, -1) + hidden_channels_1 = int(ops.ceil(output_channels / expand_ratio)) hidden_channels_2 = int(hidden_channels_1 * (expand_ratio - 1.0)) @@ -121,9 +132,13 @@ def apply_ghost_block_v2( use_depthwise=True, name=f"{name}_cheap_operation", ) - out = layers.Concatenate(name=f"{name}_concat")([x1, x2]) + out = layers.Concatenate(axis=channels_axis, name=f"{name}_concat")( + [x1, x2] + ) - residual = layers.AveragePooling2D(2, 2, name=f"{name}_avg_pool")(residual) + residual = layers.AveragePooling2D( + 2, 2, data_format=backend.image_data_format(), name=f"{name}_avg_pool" + )(residual) residual = apply_conv2d_block( residual, output_channels, @@ -149,10 +164,13 @@ def apply_ghost_block_v2( ) residual = layers.Activation("sigmoid", name=f"{name}_gate")(residual) # TODO: support dynamic shape - residual = layers.Resizing(out.shape[-3], out.shape[-2], "nearest")( - residual - ) - out = out[..., :output_channels] + residual = layers.Resizing( + out.shape[output_axis[0]], out.shape[output_axis[1]], "nearest" + )(residual) + if channels_axis == -1: + out = out[..., :output_channels] + else: + out = out[:, :output_channels, ...] out = layers.Multiply(name=name)([out, residual]) return out @@ -168,7 +186,8 @@ def apply_ghost_bottleneck( use_attention=False, # GhostNetV2 name="ghost_bottlenect", ): - input_channels = inputs.shape[-1] + channels_axis = -1 if backend.image_data_format() == "channels_last" else -3 + input_channels = inputs.shape[channels_axis] has_se = se_ratio is not None and se_ratio > 0.0 x = inputs diff --git a/kimm/models/inception_v3.py b/kimm/models/inception_v3.py index 55a3043..6dcb1f8 100644 --- a/kimm/models/inception_v3.py +++ b/kimm/models/inception_v3.py @@ -2,6 +2,7 @@ import typing import keras +from keras import backend from keras import layers from kimm.blocks import apply_conv2d_block @@ -14,6 +15,8 @@ def apply_inception_a_block(inputs, pool_channels, name="inception_a_block"): + channels_axis = -1 if backend.image_data_format() == "channels_last" else -3 + x = inputs branch1x1 = _apply_conv2d_block(x, 64, 1, 1, name=f"{name}_branch1x1") @@ -34,7 +37,9 @@ def apply_inception_a_block(inputs, pool_channels, name="inception_a_block"): ) branch_pool = layers.ZeroPadding2D(1)(x) - branch_pool = layers.AveragePooling2D(3, 1)(branch_pool) + branch_pool = layers.AveragePooling2D( + 3, 1, data_format=backend.image_data_format() + )(branch_pool) branch_pool = _apply_conv2d_block( branch_pool, pool_channels, @@ -43,11 +48,15 @@ def apply_inception_a_block(inputs, pool_channels, name="inception_a_block"): activation="relu", name=f"{name}_branch_pool", ) - x = layers.Concatenate()([branch1x1, branch5x5, branch3x3dbl, branch_pool]) + x = layers.Concatenate(axis=channels_axis)( + [branch1x1, branch5x5, branch3x3dbl, branch_pool] + ) return x def apply_inception_b_block(inputs, name="incpetion_b_block"): + channels_axis = -1 if backend.image_data_format() == "channels_last" else -3 + x = inputs branch3x3 = _apply_conv2d_block(x, 384, 3, 2, name=f"{name}_branch3x3") @@ -63,14 +72,18 @@ def apply_inception_b_block(inputs, name="incpetion_b_block"): ) branch_pool = layers.MaxPooling2D(3, 2, name=f"{name}_branch_pool")(x) - x = layers.Concatenate()([branch3x3, branch3x3dbl, branch_pool]) + x = layers.Concatenate(axis=channels_axis)( + [branch3x3, branch3x3dbl, branch_pool] + ) return x def apply_inception_c_block( inputs, branch7x7_channels, name="inception_c_block" ): + channels_axis = -1 if backend.image_data_format() == "channels_last" else -3 c7 = branch7x7_channels + x = inputs branch1x1 = _apply_conv2d_block(x, 192, 1, 1, name=f"{name}_branch1x1") @@ -105,15 +118,21 @@ def apply_inception_c_block( ) branch_pool = layers.ZeroPadding2D(1)(x) - branch_pool = layers.AveragePooling2D(3, 1)(branch_pool) + branch_pool = layers.AveragePooling2D( + 3, 1, data_format=backend.image_data_format() + )(branch_pool) branch_pool = _apply_conv2d_block( branch_pool, 192, 1, 1, name=f"{name}_branch_pool" ) - x = layers.Concatenate()([branch1x1, branch7x7, branch7x7dbl, branch_pool]) + x = layers.Concatenate(axis=channels_axis)( + [branch1x1, branch7x7, branch7x7dbl, branch_pool] + ) return x def apply_inception_d_block(inputs, name="inception_d_block"): + channels_axis = -1 if backend.image_data_format() == "channels_last" else -3 + x = inputs branch3x3 = _apply_conv2d_block(x, 192, 1, 1, name=f"{name}_branch3x3_1") @@ -135,11 +154,15 @@ def apply_inception_d_block(inputs, name="inception_d_block"): ) branch_pool = layers.MaxPooling2D(3, 2)(x) - x = layers.Concatenate()([branch3x3, branch7x7x3, branch_pool]) + x = layers.Concatenate(axis=channels_axis)( + [branch3x3, branch7x7x3, branch_pool] + ) return x def apply_inception_e_block(inputs, name="inception_e_block"): + channels_axis = -1 if backend.image_data_format() == "channels_last" else -3 + x = inputs branch1x1 = _apply_conv2d_block(x, 320, 1, 1, name=f"{name}_branch1x1") @@ -153,7 +176,7 @@ def apply_inception_e_block(inputs, name="inception_e_block"): branch3x3, 384, (3, 1), 1, padding=None, name=f"{name}_branch3x3_2b" ), ] - branch3x3 = layers.Concatenate()(branch3x3) + branch3x3 = layers.Concatenate(axis=channels_axis)(branch3x3) branch3x3dbl = _apply_conv2d_block( x, 448, 1, 1, name=f"{name}_branch3x3dbl_1" @@ -179,21 +202,27 @@ def apply_inception_e_block(inputs, name="inception_e_block"): name=f"{name}_branch3x3dbl_3b", ), ] - branch3x3dbl = layers.Concatenate()(branch3x3dbl) + branch3x3dbl = layers.Concatenate(axis=channels_axis)(branch3x3dbl) branch_pool = layers.ZeroPadding2D(1)(x) - branch_pool = layers.AveragePooling2D(3, 1)(branch_pool) + branch_pool = layers.AveragePooling2D( + 3, 1, data_format=backend.image_data_format() + )(branch_pool) branch_pool = _apply_conv2d_block( branch_pool, 192, 1, 1, name=f"{name}_branch_pool" ) - x = layers.Concatenate()([branch1x1, branch3x3, branch3x3dbl, branch_pool]) + x = layers.Concatenate(axis=channels_axis)( + [branch1x1, branch3x3, branch3x3dbl, branch_pool] + ) return x def apply_inception_aux_block(inputs, classes, name="inception_aux_block"): x = inputs - x = layers.AveragePooling2D(5, 3)(x) + x = layers.AveragePooling2D(5, 3, data_format=backend.image_data_format())( + x + ) x = _apply_conv2d_block(x, 128, 1, 1, name=f"{name}_conv0") x = _apply_conv2d_block(x, 768, 5, 1, name=f"{name}_conv1") x = layers.GlobalAveragePooling2D()(x) diff --git a/kimm/models/mobilevit.py b/kimm/models/mobilevit.py index c065d5d..583706d 100644 --- a/kimm/models/mobilevit.py +++ b/kimm/models/mobilevit.py @@ -2,6 +2,7 @@ import typing import keras +from keras import backend from keras import layers from keras import ops @@ -95,7 +96,8 @@ def apply_mobilevit_block( fusion: bool = True, name="mobilevit_block", ): - input_channels = inputs.shape[-1] + channels_axis = -1 if backend.image_data_format() == "channels_last" else -3 + input_channels = inputs.shape[channels_axis] transformer_dim = transformer_dim or make_divisible( input_channels * expansion_ratio ) @@ -115,7 +117,11 @@ def apply_mobilevit_block( transformer_dim, 1, use_bias=False, name=f"{name}_conv_1x1" )(x) + # TODO: natively support channels_first # Unfold (feature map -> patches) + if backend.image_data_format() == "channels_first": + x = ops.transpose(x, [0, 2, 3, 1]) + h, w, c = x.shape[-3], x.shape[-2], x.shape[-1] x = unfold(x, patch_size) @@ -133,11 +139,15 @@ def apply_mobilevit_block( activation=transformer_activation, name=f"{name}_transformer_{i}", ) - x = layers.LayerNormalization(epsilon=1e-6, name=f"{name}_norm")(x) + x = layers.LayerNormalization(axis=-1, epsilon=1e-6, name=f"{name}_norm")(x) # Fold (patch -> feature map) x = fold(x, h, w, c, patch_size) + # TODO: natively support channels_first + if backend.image_data_format() == "channels_first": + x = ops.transpose(x, [0, 3, 1, 2]) + x = apply_conv2d_block( x, output_channels, @@ -147,7 +157,7 @@ def apply_mobilevit_block( name=f"{name}_conv_proj", ) if fusion: - x = layers.Concatenate()([inputs, x]) + x = layers.Concatenate(axis=channels_axis)([inputs, x]) x = apply_conv2d_block( x, @@ -192,6 +202,7 @@ def __init__( input_tensor = kwargs.pop("input_tensor", None) self.set_properties(kwargs, 256) + inputs = self.determine_input_tensor( input_tensor, self._input_shape, diff --git a/kimm/models/models_test.py b/kimm/models/models_test.py index c5bfddf..73b418d 100644 --- a/kimm/models/models_test.py +++ b/kimm/models/models_test.py @@ -1,10 +1,12 @@ -import cv2 import keras import pytest +import tensorflow as tf from absl.testing import parameterized +from keras import backend from keras import models from keras import ops from keras import random +from keras import utils from keras.applications.imagenet_utils import decode_predictions from keras.src import testing @@ -344,18 +346,59 @@ class ModelTest(testing.TestCase, parameterized.TestCase): + @classmethod + def setUpClass(cls): + cls.original_image_data_format = backend.image_data_format() + + @classmethod + def tearDownClass(cls): + backend.set_image_data_format(cls.original_image_data_format) + @parameterized.named_parameters(MODEL_CONFIGS) - def test_model_base( + def test_model_base_channels_last( self, model_class, image_size, features, weights="imagenet" ): - # TODO: test the correctness of the real image + backend.set_image_data_format("channels_last") + model = model_class(weights=weights) + image_path = keras.utils.get_file( + "african_elephant.jpg", "https://i.imgur.com/Bvro0YD.png" + ) + # preprocessing + image = utils.load_img(image_path, target_size=(image_size, image_size)) + image = utils.img_to_array(image, data_format="channels_last") + x = ops.convert_to_tensor(image) + x = ops.expand_dims(x, axis=0) + + y = model(x, training=False) + + if weights == "imagenet": + names = [p[1] for p in decode_predictions(y)[0]] + # Test correct label is in top 3 (weak correctness test). + self.assertIn("African_elephant", names[:3]) + elif weights is None: + self.assertEqual(list(y.shape), [1, 1000]) + + @parameterized.named_parameters(MODEL_CONFIGS) + def test_model_base_channels_first( + self, model_class, image_size, features, weights="imagenet" + ): + if ( + len(tf.config.list_physical_devices("GPU")) == 0 + and backend.backend() == "tensorflow" + ): + self.skipTest( + "Conv2D doesn't support channels_first using CPU with " + "tensorflow backend" + ) + + backend.set_image_data_format("channels_first") model = model_class(weights=weights) image_path = keras.utils.get_file( "african_elephant.jpg", "https://i.imgur.com/Bvro0YD.png" ) # preprocessing - image = cv2.imread(image_path) - image = cv2.resize(image, (image_size, image_size)) + image = utils.load_img(image_path, target_size=(image_size, image_size)) + image = utils.img_to_array(image, data_format="channels_first") x = ops.convert_to_tensor(image) x = ops.expand_dims(x, axis=0) @@ -372,6 +415,7 @@ def test_model_base( def test_model_feature_extractor( self, model_class, image_size, features, weights="imagenet" ): + backend.set_image_data_format("channels_last") x = random.uniform([1, image_size, image_size, 3]) * 255.0 model = model_class(weights=None, feature_extractor=True) @@ -390,6 +434,7 @@ def test_model_feature_extractor( def test_model_serialization( self, model_class, image_size, features, weights="imagenet" ): + backend.set_image_data_format("channels_last") x = random.uniform([1, image_size, image_size, 3]) * 255.0 temp_dir = self.get_temp_dir() model1 = model_class(weights=None) diff --git a/kimm/models/regnet.py b/kimm/models/regnet.py index 6ab337e..5712596 100644 --- a/kimm/models/regnet.py +++ b/kimm/models/regnet.py @@ -2,6 +2,7 @@ import keras import numpy as np +from keras import backend from keras import layers from kimm.blocks import apply_conv2d_block @@ -84,7 +85,8 @@ def apply_bottleneck_block( linear_out: bool = False, name="bottleneck_block", ): - input_channels = inputs.shape[-1] + channels_axis = -1 if backend.image_data_format() == "channels_last" else -3 + input_channels = inputs.shape[channels_axis] expansion_channels = int(round(output_channels * expansion_ratio)) groups = expansion_channels // group_size diff --git a/kimm/models/resnet.py b/kimm/models/resnet.py index 298dea2..1a05129 100644 --- a/kimm/models/resnet.py +++ b/kimm/models/resnet.py @@ -1,6 +1,7 @@ import typing import keras +from keras import backend from keras import layers from kimm.blocks import apply_conv2d_block @@ -15,7 +16,8 @@ def apply_basic_block( activation="relu", name="basic_block", ): - input_channels = inputs.shape[-1] + channels_axis = -1 if backend.image_data_format() == "channels_last" else -3 + input_channels = inputs.shape[channels_axis] shortcut = inputs x = inputs x = apply_conv2d_block( @@ -58,7 +60,8 @@ def apply_bottleneck_block( activation="relu", name="bottleneck_block", ): - input_channels = inputs.shape[-1] + channels_axis = -1 if backend.image_data_format() == "channels_last" else -3 + input_channels = inputs.shape[channels_axis] expansion = 4 shortcut = inputs x = inputs diff --git a/kimm/models/vgg.py b/kimm/models/vgg.py index 069426e..fbcfa12 100644 --- a/kimm/models/vgg.py +++ b/kimm/models/vgg.py @@ -1,6 +1,7 @@ import typing import keras +from keras import backend from keras import layers from kimm.models import BaseModel @@ -132,6 +133,10 @@ def __init__(self, config: typing.Union[str, typing.List], **kwargs): input_tensor = kwargs.pop("input_tensor", None) self.set_properties(kwargs) + channels_axis = ( + -1 if backend.image_data_format() == "channels_last" else -3 + ) + inputs = self.determine_input_tensor( input_tensor, self._input_shape, @@ -166,6 +171,7 @@ def __init__(self, config: typing.Union[str, typing.List], **kwargs): name=f"features_{current_block_idx}conv2d", )(x) x = layers.BatchNormalization( + axis=channels_axis, momentum=0.9, epsilon=1e-5, name=f"features_{current_block_idx + 1}", diff --git a/kimm/models/vision_transformer.py b/kimm/models/vision_transformer.py index 82f757b..ead1e22 100644 --- a/kimm/models/vision_transformer.py +++ b/kimm/models/vision_transformer.py @@ -1,7 +1,9 @@ import typing import keras +from keras import backend from keras import layers +from keras import ops from kimm import layers as kimm_layers from kimm.blocks import apply_transformer_block @@ -45,7 +47,7 @@ def __init__( # Prepare feature extraction features = {} - # patch embedding + # Patch embedding x = layers.Conv2D( embed_dim, kernel_size=patch_size, @@ -54,6 +56,11 @@ def __init__( use_bias=True, name="patch_embed_conv", )(x) + + # TODO: natively support channels_first + if backend.image_data_format() == "channels_first": + x = ops.transpose(x, [0, 2, 3, 1]) + x = layers.Reshape((-1, embed_dim))(x) x = kimm_layers.PositionEmbedding(name="postition_embedding")(x) features["EMBEDDING"] = x diff --git a/kimm/models/xception.py b/kimm/models/xception.py index b9f743d..5e426e7 100644 --- a/kimm/models/xception.py +++ b/kimm/models/xception.py @@ -1,6 +1,7 @@ import typing import keras +from keras import backend from keras import layers from kimm.models import BaseModel @@ -16,7 +17,8 @@ def apply_xception_block( grow_first=True, name="xception_block", ): - input_channels = inputs.shape[-1] + channels_axis = -1 if backend.image_data_format() == "channels_last" else -3 + input_channels = inputs.shape[channels_axis] x = inputs residual = inputs @@ -38,6 +40,7 @@ def apply_xception_block( name=f"{name}_rep_{current_layer_idx}", )(x) x = layers.BatchNormalization( + axis=channels_axis, name=f"{name}_rep_{current_layer_idx + 1}", )(x) current_layer_idx += 2 @@ -57,7 +60,7 @@ def apply_xception_block( name=f"{name}_skipconv2d", )(residual) residual = layers.BatchNormalization( - name=f"{name}_skipbn", + axis=channels_axis, name=f"{name}_skipbn" )(residual) x = layers.Add()([x, residual]) @@ -76,6 +79,10 @@ def __init__(self, **kwargs): input_tensor = kwargs.pop("input_tensor", None) self.set_properties(kwargs) + channels_axis = ( + -1 if backend.image_data_format() == "channels_last" else -3 + ) + inputs = self.determine_input_tensor( input_tensor, self._input_shape, @@ -91,10 +98,10 @@ def __init__(self, **kwargs): # Stem x = layers.Conv2D(32, 3, 2, use_bias=False, name="conv1")(x) - x = layers.BatchNormalization(name="bn1")(x) + x = layers.BatchNormalization(axis=channels_axis, name="bn1")(x) x = layers.ReLU()(x) x = layers.Conv2D(64, 3, 1, use_bias=False, name="conv2")(x) - x = layers.BatchNormalization(name="bn2")(x) + x = layers.BatchNormalization(axis=channels_axis, name="bn2")(x) x = layers.ReLU()(x) features["STEM_S2"] = x @@ -123,13 +130,13 @@ def __init__(self, **kwargs): x = layers.SeparableConv2D( 1536, 3, 1, "same", use_bias=False, name="conv3" )(x) - x = layers.BatchNormalization(name="bn3")(x) + x = layers.BatchNormalization(axis=channels_axis, name="bn3")(x) x = layers.ReLU()(x) x = layers.SeparableConv2D( 2048, 3, 1, "same", use_bias=False, name="conv4" )(x) - x = layers.BatchNormalization(name="bn4")(x) + x = layers.BatchNormalization(axis=channels_axis, name="bn4")(x) x = layers.ReLU()(x) features["BLOCK3_S32"] = x diff --git a/pyproject.toml b/pyproject.toml index d2eed04..723c1cd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,12 @@ dependencies = ["keras"] [project.optional-dependencies] tests = [ - "opencv-python", + # export + "tf2onnx", + "onnx", + "onnxoptimizer", + "onnxsim", + # linter and formatter "isort", "ruff", "black",