diff --git a/README.md b/README.md index 479ce41..a13ebcc 100644 --- a/README.md +++ b/README.md @@ -13,6 +13,7 @@ # Keras Image Models +- [Latest Updates](#latest-updates) - [Introduction](#introduction) - [Usage](#usage) - [Installation](#installation) @@ -24,6 +25,13 @@ - [License](#license) - [Acknowledgements](#acknowledgements) +## Latest Updates + +2024/05/29: + +- Merge reparameterizable layers into 1 `ReparameterizableConv2D` +- Add `GhostNetV3*` from [huawei-noah/Efficient-AI-Backbones](https://github.com/huawei-noah/Efficient-AI-Backbones) + ## Introduction **K**eras **Im**age **M**odels (`kimm`) is a collection of image models, blocks and layers written in Keras 3. The goal is to offer SOTA models with pretrained weights in a user-friendly manner. @@ -154,6 +162,7 @@ Reference: [Grad-CAM class activation visualization (keras.io)](https://keras.io |EfficientNetV2|[ICML 2021](https://arxiv.org/abs/2104.00298)|`timm`|`kimm.models.EfficientNetV2*`| |GhostNet|[CVPR 2020](https://arxiv.org/abs/1911.11907)|`timm`|`kimm.models.GhostNet*`| |GhostNetV2|[NeurIPS 2022](https://arxiv.org/abs/2211.12905)|`timm`|`kimm.models.GhostNetV2*`| +|GhostNetV3|[arXiv 2024](https://arxiv.org/abs/2404.11202)|`github`|`kimm.models.GhostNetV3*`| |HGNet||`timm`|`kimm.models.HGNet*`| |HGNetV2||`timm`|`kimm.models.HGNetV2*`| |InceptionNeXt|[arXiv 2023](https://arxiv.org/abs/2303.16900)|`timm`|`kimm.models.InceptionNeXt*`| diff --git a/kimm/__init__.py b/kimm/__init__.py index 59921a7..9c03477 100644 --- a/kimm/__init__.py +++ b/kimm/__init__.py @@ -13,4 +13,4 @@ from kimm._src.utils.model_registry import list_models from kimm._src.version import version -__version__ = "0.2.1" +__version__ = "0.2.2" diff --git a/kimm/_src/layers/rep_conv2d.py b/kimm/_src/layers/rep_conv2d.py deleted file mode 100644 index cc223d9..0000000 --- a/kimm/_src/layers/rep_conv2d.py +++ /dev/null @@ -1,287 +0,0 @@ -import keras -import numpy as np -from keras import Sequential -from keras import layers -from keras import ops -from keras.src.backend import standardize_data_format -from keras.src.layers import Layer -from keras.src.utils.argument_validation import standardize_tuple - -from kimm._src.kimm_export import kimm_export - - -@kimm_export(parent_path=["kimm.layers"]) -@keras.saving.register_keras_serializable(package="kimm") -class RepConv2D(Layer): - def __init__( - self, - filters, - kernel_size, - strides=(1, 1), - padding=None, - has_skip: bool = True, - reparameterized: bool = False, - data_format=None, - activation=None, - **kwargs, - ): - super().__init__(**kwargs) - self.filters = filters - self.kernel_size = standardize_tuple(kernel_size, 2, "kernel_size") - self.strides = standardize_tuple(strides, 2, "strides") - self.padding = padding - self.has_skip = has_skip - self._reparameterized = reparameterized - self.data_format = standardize_data_format(data_format) - self.activation = activation - - if self.kernel_size[0] != self.kernel_size[1]: - raise ValueError( - "The value of kernel_size must be the same. " - f"Received: kernel_size={kernel_size}" - ) - if self.strides[0] != self.strides[1]: - raise ValueError( - "The value of strides must be the same. " - f"Received: strides={strides}" - ) - if has_skip is True and (self.strides[0] != 1 or self.strides[1] != 1): - raise ValueError( - "strides must be `1` when `has_skip=True`. " - f"Received: has_skip={has_skip}, strides={strides}" - ) - - self.zero_padding = layers.Identity(dtype=self.dtype_policy) - if padding is None: - padding = "same" - if self.strides[0] > 1: - padding = "valid" - self.zero_padding = layers.ZeroPadding2D( - (self.kernel_size[0] // 2, self.kernel_size[1] // 2), - data_format=self.data_format, - dtype=self.dtype_policy, - name=f"{self.name}_pad", - ) - self.padding = padding - else: - self.padding = padding - - channel_axis = -1 if self.data_format == "channels_last" else -3 - if self._reparameterized: - self.rep_conv2d = layers.Conv2D( - self.filters, - self.kernel_size, - self.strides, - self.padding, - data_format=self.data_format, - use_bias=True, - dtype=self.dtype_policy, - name=f"{self.name}_reparam_conv", - ) - self.identity = None - self.conv_kxk = None - self.conv_1x1 = None - else: - self.rep_conv2d = None - if self.has_skip: - self.identity = layers.BatchNormalization( - axis=channel_axis, - momentum=0.9, - epsilon=1e-5, - dtype=self.dtype_policy, - name=f"{self.name}_identity", - ) - else: - self.identity = None - self.conv_kxk = Sequential( - [ - layers.Conv2D( - self.filters, - self.kernel_size, - self.strides, - padding=self.padding, - data_format=self.data_format, - use_bias=False, - dtype=self.dtype_policy, - ), - layers.BatchNormalization( - axis=channel_axis, - momentum=0.9, - epsilon=1e-5, - dtype=self.dtype_policy, - ), - ], - name=f"{self.name}_conv_kxk", - ) - self.conv_1x1 = Sequential( - [ - layers.Conv2D( - self.filters, - 1, - self.strides, - padding=self.padding, - data_format=self.data_format, - use_bias=False, - dtype=self.dtype_policy, - ), - layers.BatchNormalization( - axis=channel_axis, - momentum=0.9, - epsilon=1e-5, - dtype=self.dtype_policy, - ), - ], - name=f"{self.name}_conv_1x1", - ) - - if activation is None: - self.act = layers.Identity(dtype=self.dtype_policy) - else: - self.act = layers.Activation(activation, dtype=self.dtype_policy) - - # Internal parameters for `_get_reparameterized_weights_from_layer` - self._input_channels = None - self._rep_kernel_shape = None - - # Attach extra layers - self.extra_layers = [] - if self.rep_conv2d is not None: - self.extra_layers.append(self.rep_conv2d) - if self.identity is not None: - self.extra_layers.append(self.identity) - if self.conv_kxk is not None: - self.extra_layers.append(self.conv_kxk) - if self.conv_1x1 is not None: - self.extra_layers.append(self.conv_1x1) - self.extra_layers.append(self.act) - - def build(self, input_shape): - channel_axis = -1 if self.data_format == "channels_last" else -3 - - if isinstance(self.zero_padding, layers.ZeroPadding2D): - padded_shape = self.zero_padding.compute_output_shape(input_shape) - else: - padded_shape = input_shape - - if self.rep_conv2d is not None: - self.rep_conv2d.build(padded_shape) - if self.identity is not None: - self.identity.build(input_shape) - if self.conv_kxk is not None: - self.conv_kxk.build(padded_shape) - if self.conv_1x1 is not None: - self.conv_1x1.build(input_shape) - - # Update internal parameters - self._input_channels = input_shape[channel_axis] - if self.conv_kxk is not None: - self._rep_kernel_shape = self.conv_kxk.layers[0].kernel.shape - - self.built = True - - def call(self, inputs, **kwargs): - x = ops.cast(inputs, self.compute_dtype) - padded_x = self.zero_padding(x) - - # Shortcut for reparameterized mode - if self._reparameterized: - return self.act(self.rep_conv2d(padded_x, **kwargs)) - - if self.identity is None: - x = self.conv_1x1(x, **kwargs) + self.conv_kxk(padded_x, **kwargs) - else: - identity = self.identity(x, **kwargs) - x = self.conv_1x1(x, **kwargs) + self.conv_kxk(padded_x, **kwargs) - x = x + identity - return self.act(x) - - def get_config(self): - config = super().get_config() - config.update( - { - "filters": self.filters, - "kernel_size": self.kernel_size, - "strides": self.strides, - "padding": self.padding, - "has_skip": self.has_skip, - "reparameterized": self._reparameterized, - "data_format": self.data_format, - "activation": self.activation, - "name": self.name, - } - ) - return config - - def _get_reparameterized_weights_from_layer(self, layer): - if isinstance(layer, Sequential): - if not isinstance(layer.layers[0], layers.Conv2D): - raise ValueError - if not isinstance(layer.layers[1], layers.BatchNormalization): - raise ValueError - kernel = ops.convert_to_numpy(layer.layers[0].kernel) - gamma = ops.convert_to_numpy(layer.layers[1].gamma) - beta = ops.convert_to_numpy(layer.layers[1].beta) - running_mean = ops.convert_to_numpy(layer.layers[1].moving_mean) - running_var = ops.convert_to_numpy(layer.layers[1].moving_variance) - eps = layer.layers[1].epsilon - elif isinstance(layer, layers.BatchNormalization): - if self._rep_kernel_shape is None: - raise ValueError( - "Remember to build the layer before performing" - "reparameterization. Failed to get valid " - "`self._rep_kernel_shape`." - ) - # Calculate identity tensor - kernel_value = ops.convert_to_numpy( - ops.zeros(self._rep_kernel_shape) - ) - kernel_value = kernel_value.copy() - for i in range(self._input_channels): - kernel_value[ - self.kernel_size[0] // 2, self.kernel_size[1] // 2, i, i - ] = 1 - kernel = kernel_value - gamma = ops.convert_to_numpy(layer.gamma) - beta = ops.convert_to_numpy(layer.beta) - running_mean = ops.convert_to_numpy(layer.moving_mean) - running_var = ops.convert_to_numpy(layer.moving_variance) - eps = layer.epsilon - - # Use float64 for better precision - kernel = kernel.astype("float64") - gamma = gamma.astype("float64") - beta = beta.astype("float64") - running_var = running_var.astype("float64") - running_var = running_var.astype("float64") - - std = np.sqrt(running_var + eps) - t = np.reshape(gamma / std, [1, 1, 1, -1]) - return kernel * t, beta - running_mean * gamma / std - - def get_reparameterized_weights(self): - kernel_1x1 = 0.0 - bias_1x1 = 0.0 - if self.conv_1x1 is not None: - kernel_1x1, bias_1x1 = self._get_reparameterized_weights_from_layer( - self.conv_1x1 - ) - pad = self.kernel_size[0] // 2 - kernel_1x1 = np.pad( - kernel_1x1, [[pad, pad], [pad, pad], [0, 0], [0, 0]] - ) - - kernel_identity = 0.0 - bias_identity = 0.0 - if self.identity is not None: - ( - kernel_identity, - bias_identity, - ) = self._get_reparameterized_weights_from_layer(self.identity) - - kernel_conv, bias_conv = self._get_reparameterized_weights_from_layer( - self.conv_kxk - ) - - kernel_final = kernel_conv + kernel_1x1 + kernel_identity - bias_final = bias_conv + bias_1x1 + bias_identity - return kernel_final, bias_final diff --git a/kimm/_src/layers/rep_conv2d_test.py b/kimm/_src/layers/rep_conv2d_test.py deleted file mode 100644 index 4712d17..0000000 --- a/kimm/_src/layers/rep_conv2d_test.py +++ /dev/null @@ -1,134 +0,0 @@ -import pytest -from absl.testing import parameterized -from keras import backend -from keras import random -from keras.src import testing - -from kimm._src.layers.rep_conv2d import RepConv2D - -TEST_CASES = [ - { - "filters": 16, - "kernel_size": 3, - "has_skip": True, - "data_format": "channels_last", - "input_shape": (1, 4, 4, 16), - "output_shape": (1, 4, 4, 16), - "num_trainable_weights": 8, - "num_non_trainable_weights": 6, - }, - { - "filters": 16, - "kernel_size": 3, - "has_skip": False, - "data_format": "channels_last", - "input_shape": (1, 4, 4, 8), - "output_shape": (1, 4, 4, 16), - "num_trainable_weights": 6, - "num_non_trainable_weights": 4, - }, - { - "filters": 16, - "kernel_size": 5, - "has_skip": True, - "data_format": "channels_last", - "input_shape": (1, 4, 4, 16), - "output_shape": (1, 4, 4, 16), - "num_trainable_weights": 8, - "num_non_trainable_weights": 6, - }, - { - "filters": 16, - "kernel_size": 3, - "has_skip": True, - "data_format": "channels_first", - "input_shape": (1, 16, 4, 4), - "output_shape": (1, 16, 4, 4), - "num_trainable_weights": 8, - "num_non_trainable_weights": 6, - }, -] - - -class RepConv2DTest(testing.TestCase, parameterized.TestCase): - @parameterized.parameters(TEST_CASES) - @pytest.mark.requires_trainable_backend - def test_basic( - self, - filters, - kernel_size, - has_skip, - data_format, - input_shape, - output_shape, - num_trainable_weights, - num_non_trainable_weights, - ): - if ( - backend.backend() == "tensorflow" - and data_format == "channels_first" - ): - self.skipTest( - "Conv2D in tensorflow backend with 'channels_first' is limited " - "to be supported" - ) - self.run_layer_test( - RepConv2D, - init_kwargs={ - "filters": filters, - "kernel_size": kernel_size, - "has_skip": has_skip, - "data_format": data_format, - }, - input_shape=input_shape, - expected_output_shape=output_shape, - expected_num_trainable_weights=num_trainable_weights, - expected_num_non_trainable_weights=num_non_trainable_weights, - expected_num_losses=0, - supports_masking=False, - ) - - @parameterized.parameters(TEST_CASES) - def test_get_reparameterized_weights( - self, - filters, - kernel_size, - has_skip, - data_format, - input_shape, - output_shape, - num_trainable_weights, - num_non_trainable_weights, - ): - if ( - backend.backend() == "tensorflow" - and data_format == "channels_first" - ): - self.skipTest( - "Conv2D in tensorflow backend with 'channels_first' is limited " - "to be supported" - ) - layer = RepConv2D( - filters=filters, - kernel_size=kernel_size, - has_skip=has_skip, - data_format=data_format, - ) - layer.build(input_shape) - reparameterized_layer = RepConv2D( - filters=filters, - kernel_size=kernel_size, - has_skip=has_skip, - reparameterized=True, - data_format=data_format, - ) - reparameterized_layer.build(input_shape) - x = random.uniform(input_shape) - - kernel, bias = layer.get_reparameterized_weights() - reparameterized_layer.rep_conv2d.kernel.assign(kernel) - reparameterized_layer.rep_conv2d.bias.assign(bias) - y1 = layer(x, training=False) - y2 = reparameterized_layer(x, training=False) - - self.assertAllClose(y1, y2, atol=1e-3) diff --git a/kimm/_src/layers/mobile_one_conv2d.py b/kimm/_src/layers/reparameterizable_conv2d.py similarity index 62% rename from kimm/_src/layers/mobile_one_conv2d.py rename to kimm/_src/layers/reparameterizable_conv2d.py index df32aa2..4145d8e 100644 --- a/kimm/_src/layers/mobile_one_conv2d.py +++ b/kimm/_src/layers/reparameterizable_conv2d.py @@ -14,7 +14,7 @@ @kimm_export(parent_path=["kimm.layers"]) @keras.saving.register_keras_serializable(package="kimm") -class MobileOneConv2D(Layer): +class ReparameterizableConv2D(Layer): def __init__( self, filters, @@ -22,66 +22,73 @@ def __init__( strides=(1, 1), padding=None, has_skip: bool = True, + has_scale: bool = True, use_depthwise: bool = False, branch_size: int = 1, reparameterized: bool = False, data_format=None, activation=None, + name=None, **kwargs, ): - super().__init__(**kwargs) + super().__init__(name=name, **kwargs) self.filters = filters self.kernel_size = standardize_tuple(kernel_size, 2, "kernel_size") self.strides = standardize_tuple(strides, 2, "strides") self.padding = padding self.has_skip = has_skip + self.has_scale = has_scale self.use_depthwise = use_depthwise self.branch_size = branch_size - self._reparameterized = reparameterized + self.reparameterized = reparameterized self.data_format = standardize_data_format(data_format) self.activation = activation if self.kernel_size[0] != self.kernel_size[1]: raise ValueError( - "The value of kernel_size must be the same. " + "The values of `kernel_size` must be the same. " f"Received: kernel_size={kernel_size}" ) if self.strides[0] != self.strides[1]: raise ValueError( - "The value of strides must be the same. " + "The values of `strides` must be the same. " f"Received: strides={strides}" ) - if has_skip is True and (self.strides[0] != 1 or self.strides[1] != 1): + if has_skip is True and self.strides[0] != 1: raise ValueError( - "strides must be `1` when `has_skip=True`. " + "When `has_skip=True`, `strides` must be `1`. " f"Received: has_skip={has_skip}, strides={strides}" ) - self.zero_padding = layers.Identity(dtype=self.dtype_policy) - if padding is None: - padding = "same" - if self.strides[0] > 1: - padding = "valid" + # Configure zero padding + self.zero_padding: typing.Optional[layers.ZeroPadding2D] = None + if self.padding is None: + if self.strides[0] > 1 and self.kernel_size[0] > 1: + self.padding = "valid" self.zero_padding = layers.ZeroPadding2D( - (self.kernel_size[0] // 2, self.kernel_size[1] // 2), + self.kernel_size[0] // 2, data_format=self.data_format, dtype=self.dtype_policy, name=f"{self.name}_pad", ) - self.padding = padding - else: - self.padding = padding + else: + self.padding = "same" + + # Configure filters_axis + self.filters_axis = -1 if self.data_format == "channels_last" else -3 - channel_axis = -1 if self.data_format == "channels_last" else -3 + # Build layers + bn_momentum, bn_epsilon = 0.9, 1e-5 # Defaults to torch's default - # Build layers (rep_conv2d, identity, conv_kxk, conv_scale) - self.rep_conv2d: typing.Optional[layers.Conv2D] = None - self.identity: typing.Optional[layers.BatchNormalization] = None - self.conv_kxk: typing.Optional[typing.List[Sequential]] = None + self.reparameterized_conv2d: typing.Optional[layers.Conv2D] = None + self.skip: typing.Optional[layers.BatchNormalization] = None self.conv_scale: typing.Optional[Sequential] = None - if self._reparameterized: - self.rep_conv2d = self._get_conv2d( - use_depthwise, + self.conv_kxk: typing.List[Sequential] = [] + self.act: typing.Optional[layers.Activation] = None + + if self.reparameterized: + self.reparameterized_conv2d = self._get_conv2d_layer( + self.use_depthwise, self.filters, self.kernel_size, self.strides, @@ -90,25 +97,42 @@ def __init__( name=f"{self.name}_reparam_conv", ) else: - # Skip connection + # Skip branch if self.has_skip: - self.identity = layers.BatchNormalization( - axis=channel_axis, - momentum=0.9, - epsilon=1e-5, + self.skip = layers.BatchNormalization( + axis=self.filters_axis, + momentum=bn_momentum, + epsilon=bn_epsilon, dtype=self.dtype_policy, - name=f"{self.name}_identity", + name=f"{self.name}_skip", ) - else: - self.identity = None - - # Convoluation branches - self.conv_kxk = [] + # Scale branch + if self.has_scale: + self.conv_scale = Sequential( + [ + self._get_conv2d_layer( + self.use_depthwise, + self.filters, + 1, + self.strides, + self.padding, + use_bias=False, + ), + layers.BatchNormalization( + axis=self.filters_axis, + momentum=bn_momentum, + epsilon=bn_epsilon, + dtype=self.dtype_policy, + ), + ], + name=f"{self.name}_conv_scale", + ) + # Overparameterized branch for i in range(self.branch_size): self.conv_kxk.append( Sequential( [ - self._get_conv2d( + self._get_conv2d_layer( self.use_depthwise, self.filters, self.kernel_size, @@ -117,61 +141,38 @@ def __init__( use_bias=False, ), layers.BatchNormalization( - axis=channel_axis, - momentum=0.9, - epsilon=1e-5, + axis=self.filters_axis, + momentum=bn_momentum, + epsilon=bn_epsilon, dtype=self.dtype_policy, ), ], name=f"{self.name}_conv_kxk_{i}", ) ) - - # Scale branch - self.conv_scale = None - if self.kernel_size[0] > 1: - self.conv_scale = Sequential( - [ - self._get_conv2d( - self.use_depthwise, - self.filters, - 1, - self.strides, - self.padding, - use_bias=False, - ), - layers.BatchNormalization( - axis=channel_axis, - momentum=0.9, - epsilon=1e-5, - dtype=self.dtype_policy, - ), - ], - name=f"{self.name}_conv_scale", - ) - - if activation is None: - self.act = layers.Identity(dtype=self.dtype_policy) - else: + if activation is not None: self.act = layers.Activation(activation, dtype=self.dtype_policy) - # Internal parameters for `_get_reparameterized_weights_from_layer` - self._input_channels = None - self._rep_kernel_shape = None - - # Attach extra layers - self.extra_layers = [] - if self.rep_conv2d is not None: - self.extra_layers.append(self.rep_conv2d) - if self.identity is not None: - self.extra_layers.append(self.identity) - if self.conv_kxk is not None: - self.extra_layers.extend(self.conv_kxk) + @property + def _sublayers(self): + """An internal api for weights exporting. + + Generally, you don't need this. + """ + sublayers = [] + if self.reparameterized_conv2d is not None: + sublayers.append(self.reparameterized_conv2d) + if self.skip is not None: + sublayers.append(self.skip) if self.conv_scale is not None: - self.extra_layers.append(self.conv_scale) - self.extra_layers.append(self.act) + sublayers.append(self.conv_scale) + if self.conv_kxk is not None: + sublayers.extend(self.conv_kxk) + if self.act is not None: + sublayers.append(self.act) + return sublayers - def _get_conv2d( + def _get_conv2d_layer( self, use_depthwise, filters, @@ -204,63 +205,67 @@ def _get_conv2d( ) def build(self, input_shape): - channel_axis = -1 if self.data_format == "channels_last" else -3 + input_filters = input_shape[self.filters_axis] + if self.use_depthwise and input_filters != self.filters: + raise ValueError( + "When `use_depthwise=True`, `filters` must be the same as " + f"input filters. Received: input_shape={input_shape}, " + f"filters={self.filters}" + ) if isinstance(self.zero_padding, layers.ZeroPadding2D): - padded_shape = self.zero_padding.compute_output_shape(input_shape) - else: - padded_shape = input_shape + input_shape = self.zero_padding.compute_output_shape(input_shape) - if self.rep_conv2d is not None: - self.rep_conv2d.build(padded_shape) - if self.identity is not None: - self.identity.build(input_shape) - if self.conv_kxk is not None: - for layer in self.conv_kxk: - layer.build(padded_shape) + if self.reparameterized_conv2d is not None: + self.reparameterized_conv2d.build(input_shape) + + if self.skip is not None: + self.skip.build(input_shape) if self.conv_scale is not None: self.conv_scale.build(input_shape) + for layer in self.conv_kxk: + layer.build(input_shape) # Update internal parameters - self._input_channels = input_shape[channel_axis] - if self.conv_kxk is not None: - self._rep_kernel_shape = self.conv_kxk[0].layers[0].kernel.shape + self.input_filters = input_filters self.built = True - def call(self, inputs, **kwargs): - x = ops.cast(inputs, self.compute_dtype) - padded_x = self.zero_padding(x) + def call(self, inputs, training=None, **kwargs): + x = inputs + padded_x = x - # Shortcut for reparameterized mode - if self._reparameterized: - return self.act(self.rep_conv2d(padded_x, **kwargs)) + if self.zero_padding is not None: + padded_x = self.zero_padding(x) - # Skip connection - identity_outputs = None - if self.identity is not None: - identity_outputs = self.identity(x, **kwargs) + # Shortcut for reparameterized=True + if self.reparameterized: + y = self.reparameterized_conv2d(padded_x) + if self.act is not None: + y = self.act(y) + return y + # Skip branch + y = None + if self.skip is not None: + y = self.skip(x, training=training) # Scale branch - scale_outputs = None if self.conv_scale is not None: - scale_outputs = self.conv_scale(x, **kwargs) - - # Conv branch - conv_outputs = scale_outputs - for layer in self.conv_kxk: - if conv_outputs is None: - conv_outputs = layer(padded_x, **kwargs) + scale_y = self.conv_scale(x, training=training) + if y is None: + y = scale_y else: - conv_outputs = layers.Add()( - [conv_outputs, layer(padded_x, **kwargs)] - ) - - if identity_outputs is not None: - outputs = layers.Add()([conv_outputs, identity_outputs]) - else: - outputs = conv_outputs - return self.act(outputs) + y = layers.Add(dtype=self.dtype_policy)([y, scale_y]) + # Overparameterized bracnh + for idx in range(self.branch_size): + over_y = self.conv_kxk[idx](padded_x, training=training) + if y is None: + y = over_y + else: + y = layers.Add(dtype=self.dtype_policy)([y, over_y]) + if self.act is not None: + y = self.act(y) + return y def get_config(self): config = super().get_config() @@ -271,9 +276,10 @@ def get_config(self): "strides": self.strides, "padding": self.padding, "has_skip": self.has_skip, + "has_scale": self.has_scale, "use_depthwise": self.use_depthwise, "branch_size": self.branch_size, - "reparameterized": self._reparameterized, + "reparameterized": self.reparameterized, "data_format": self.data_format, "activation": self.activation, "name": self.name, @@ -283,6 +289,7 @@ def get_config(self): def _get_reparameterized_weights_from_layer(self, layer): if isinstance(layer, Sequential): + # Check if not isinstance( layer.layers[0], (layers.Conv2D, layers.DepthwiseConv2D) ): @@ -298,32 +305,21 @@ def _get_reparameterized_weights_from_layer(self, layer): running_var = ops.convert_to_numpy(layer.layers[1].moving_variance) eps = layer.layers[1].epsilon elif isinstance(layer, layers.BatchNormalization): - if self._rep_kernel_shape is None: - raise ValueError( - "Remember to build the layer before performing" - "reparameterization. Failed to get valid " - "`self._rep_kernel_shape`." - ) - # Calculate identity tensor - kernel_value = ops.convert_to_numpy( - ops.zeros(self._rep_kernel_shape) + k = self.kernel_size[0] + input_filters = 1 if self.use_depthwise else self.input_filters + kernel = np.zeros( + shape=[k, k, input_filters, self.filters], dtype="float64" ) - kernel = kernel_value.copy() - if self.use_depthwise: - kernel = np.swapaxes(kernel, -2, -1) - for i in range(self._input_channels): + for i in range(self.input_filters): group_i = 0 if self.use_depthwise else i - kernel[ - self.kernel_size[0] // 2, - self.kernel_size[1] // 2, - group_i, - i, - ] = 1 + kernel[k // 2, k // 2, group_i, i] = 1 gamma = ops.convert_to_numpy(layer.gamma) beta = ops.convert_to_numpy(layer.beta) running_mean = ops.convert_to_numpy(layer.moving_mean) running_var = ops.convert_to_numpy(layer.moving_variance) eps = layer.epsilon + else: + raise NotImplementedError # use float64 for better precision kernel = kernel.astype("float64") @@ -341,6 +337,15 @@ def _get_reparameterized_weights_from_layer(self, layer): return kernel_final, beta - running_mean * gamma / std def get_reparameterized_weights(self): + # Get kernels and bias from skip branch + kernel_identity = 0.0 + bias_identity = 0.0 + if self.skip is not None: + ( + kernel_identity, + bias_identity, + ) = self._get_reparameterized_weights_from_layer(self.skip) + # Get kernels and bias from scale branch kernel_scale = 0.0 bias_scale = 0.0 @@ -354,16 +359,7 @@ def get_reparameterized_weights(self): kernel_scale, [[pad, pad], [pad, pad], [0, 0], [0, 0]] ) - # Get kernels and bias from skip branch - kernel_identity = 0.0 - bias_identity = 0.0 - if self.identity is not None: - ( - kernel_identity, - bias_identity, - ) = self._get_reparameterized_weights_from_layer(self.identity) - - # Get kernels and bias from conv branch + # Get kernels and bias from overparameterized branch kernel_conv = 0.0 bias_conv = 0.0 for i in range(self.branch_size): diff --git a/kimm/_src/layers/mobile_one_conv2d_test.py b/kimm/_src/layers/reparameterizable_conv2d_test.py similarity index 70% rename from kimm/_src/layers/mobile_one_conv2d_test.py rename to kimm/_src/layers/reparameterizable_conv2d_test.py index 1284739..d312586 100644 --- a/kimm/_src/layers/mobile_one_conv2d_test.py +++ b/kimm/_src/layers/reparameterizable_conv2d_test.py @@ -4,13 +4,14 @@ from keras import random from keras.src import testing -from kimm._src.layers.mobile_one_conv2d import MobileOneConv2D +from kimm._src.layers.reparameterizable_conv2d import ReparameterizableConv2D TEST_CASES = [ { "filters": 16, "kernel_size": 3, "has_skip": True, + "has_scale": True, "use_depthwise": False, "branch_size": 2, "data_format": "channels_last", @@ -23,6 +24,7 @@ "filters": 16, "kernel_size": 3, "has_skip": True, + "has_scale": True, "use_depthwise": True, "branch_size": 3, "data_format": "channels_last", @@ -35,6 +37,7 @@ "filters": 16, "kernel_size": 3, "has_skip": False, + "has_scale": True, "use_depthwise": False, "branch_size": 2, "data_format": "channels_last", @@ -47,6 +50,7 @@ "filters": 16, "kernel_size": 5, "has_skip": True, + "has_scale": True, "use_depthwise": False, "branch_size": 2, "data_format": "channels_last", @@ -59,6 +63,7 @@ "filters": 16, "kernel_size": 3, "has_skip": True, + "has_scale": True, "use_depthwise": False, "branch_size": 2, "data_format": "channels_first", @@ -67,10 +72,36 @@ "num_trainable_weights": 11, "num_non_trainable_weights": 8, }, + { + "filters": 16, + "kernel_size": 1, + "has_skip": True, + "has_scale": False, + "use_depthwise": False, + "branch_size": 2, + "data_format": "channels_last", + "input_shape": (1, 4, 4, 16), + "output_shape": (1, 4, 4, 16), + "num_trainable_weights": 8, + "num_non_trainable_weights": 6, + }, + { + "filters": 16, + "kernel_size": 1, + "has_skip": False, + "has_scale": False, + "use_depthwise": True, + "branch_size": 3, + "data_format": "channels_last", + "input_shape": (1, 4, 4, 16), + "output_shape": (1, 4, 4, 16), + "num_trainable_weights": 9, + "num_non_trainable_weights": 6, + }, ] -class MobileOneConv2DTest(testing.TestCase, parameterized.TestCase): +class ReparameterizableConv2DTest(testing.TestCase, parameterized.TestCase): @parameterized.parameters(TEST_CASES) @pytest.mark.requires_trainable_backend def test_basic( @@ -78,6 +109,7 @@ def test_basic( filters, kernel_size, has_skip, + has_scale, use_depthwise, branch_size, data_format, @@ -95,11 +127,12 @@ def test_basic( "to be supported" ) self.run_layer_test( - MobileOneConv2D, + ReparameterizableConv2D, init_kwargs={ "filters": filters, "kernel_size": kernel_size, "has_skip": has_skip, + "has_scale": has_scale, "use_depthwise": use_depthwise, "branch_size": branch_size, "data_format": data_format, @@ -118,6 +151,7 @@ def test_get_reparameterized_weights( filters, kernel_size, has_skip, + has_scale, use_depthwise, branch_size, data_format, @@ -134,19 +168,21 @@ def test_get_reparameterized_weights( "Conv2D in tensorflow backend with 'channels_first' is limited " "to be supported" ) - layer = MobileOneConv2D( + layer = ReparameterizableConv2D( filters=filters, kernel_size=kernel_size, has_skip=has_skip, + has_scale=has_scale, use_depthwise=use_depthwise, branch_size=branch_size, data_format=data_format, ) layer.build(input_shape) - reparameterized_layer = MobileOneConv2D( + reparameterized_layer = ReparameterizableConv2D( filters=filters, kernel_size=kernel_size, has_skip=has_skip, + has_scale=has_scale, use_depthwise=use_depthwise, branch_size=branch_size, reparameterized=True, @@ -156,9 +192,22 @@ def test_get_reparameterized_weights( x = random.uniform(input_shape) kernel, bias = layer.get_reparameterized_weights() - reparameterized_layer.rep_conv2d.kernel.assign(kernel) - reparameterized_layer.rep_conv2d.bias.assign(bias) + reparameterized_layer.reparameterized_conv2d.kernel.assign(kernel) + reparameterized_layer.reparameterized_conv2d.bias.assign(bias) y1 = layer(x, training=False) y2 = reparameterized_layer(x, training=False) self.assertAllClose(y1, y2, atol=1e-3) + + def test_invalid_args(self): + layer = ReparameterizableConv2D( + filters=4, + kernel_size=3, + has_skip=False, + has_scale=False, + use_depthwise=True, + branch_size=1, + data_format="channels_last", + ) + with self.assertRaisesRegex(ValueError, "must be the same as"): + layer.build([1, 4, 4, 8]) diff --git a/kimm/_src/models/__init__.py b/kimm/_src/models/__init__.py index 7c32b20..b6f1ebb 100644 --- a/kimm/_src/models/__init__.py +++ b/kimm/_src/models/__init__.py @@ -3,6 +3,7 @@ from kimm._src.models import densenet from kimm._src.models import efficientnet from kimm._src.models import ghostnet +from kimm._src.models import ghostnet_v3 from kimm._src.models import hgnet from kimm._src.models import inception_next from kimm._src.models import inception_v3 diff --git a/kimm/_src/models/ghostnet.py b/kimm/_src/models/ghostnet.py index 8522bed..81cf913 100644 --- a/kimm/_src/models/ghostnet.py +++ b/kimm/_src/models/ghostnet.py @@ -431,7 +431,7 @@ def __init__( @kimm_export(parent_path=["kimm.models", "kimm.models.ghostnet"]) -class GhostNet050(GhostNetVariant): +class GhostNetW050(GhostNetVariant): available_weights = [] # Parameters @@ -441,7 +441,7 @@ class GhostNet050(GhostNetVariant): @kimm_export(parent_path=["kimm.models", "kimm.models.ghostnet"]) -class GhostNet100(GhostNetVariant): +class GhostNetW100(GhostNetVariant): available_weights = [ ( "imagenet", @@ -457,7 +457,7 @@ class GhostNet100(GhostNetVariant): @kimm_export(parent_path=["kimm.models", "kimm.models.ghostnet"]) -class GhostNet130(GhostNetVariant): +class GhostNetW130(GhostNetVariant): available_weights = [] # Parameters @@ -467,7 +467,7 @@ class GhostNet130(GhostNetVariant): @kimm_export(parent_path=["kimm.models", "kimm.models.ghostnet"]) -class GhostNet100V2(GhostNetVariant): +class GhostNetV2W100(GhostNetVariant): available_weights = [ ( "imagenet", @@ -483,7 +483,7 @@ class GhostNet100V2(GhostNetVariant): @kimm_export(parent_path=["kimm.models", "kimm.models.ghostnet"]) -class GhostNet130V2(GhostNetVariant): +class GhostNetV2W130(GhostNetVariant): available_weights = [ ( "imagenet", @@ -499,7 +499,7 @@ class GhostNet130V2(GhostNetVariant): @kimm_export(parent_path=["kimm.models", "kimm.models.ghostnet"]) -class GhostNet160V2(GhostNetVariant): +class GhostNetV2W160(GhostNetVariant): available_weights = [ ( "imagenet", @@ -514,9 +514,9 @@ class GhostNet160V2(GhostNetVariant): version = "v2" -add_model_to_registry(GhostNet050) -add_model_to_registry(GhostNet100, "imagenet") -add_model_to_registry(GhostNet130) -add_model_to_registry(GhostNet100V2, "imagenet") -add_model_to_registry(GhostNet130V2, "imagenet") -add_model_to_registry(GhostNet160V2, "imagenet") +add_model_to_registry(GhostNetW050) +add_model_to_registry(GhostNetW100, "imagenet") +add_model_to_registry(GhostNetW130) +add_model_to_registry(GhostNetV2W100, "imagenet") +add_model_to_registry(GhostNetV2W130, "imagenet") +add_model_to_registry(GhostNetV2W160, "imagenet") diff --git a/kimm/_src/models/ghostnet_v3.py b/kimm/_src/models/ghostnet_v3.py new file mode 100644 index 0000000..cfe3a17 --- /dev/null +++ b/kimm/_src/models/ghostnet_v3.py @@ -0,0 +1,513 @@ +import typing +import warnings + +import keras +from keras import backend +from keras import layers +from keras import ops + +from kimm._src.blocks.conv2d import apply_conv2d_block +from kimm._src.blocks.squeeze_and_excitation import apply_se_block +from kimm._src.kimm_export import kimm_export +from kimm._src.layers.reparameterizable_conv2d import ReparameterizableConv2D +from kimm._src.models.base_model import BaseModel +from kimm._src.utils.make_divisble import make_divisible +from kimm._src.utils.model_registry import add_model_to_registry + +DEFAULT_CONFIG = [ + # k, t, c, SE, s + # stage1 + [ + [3, 16, 16, 0, 1], + ], + # stage2 + [ + [3, 48, 24, 0, 2], + ], + [ + [3, 72, 24, 0, 1], + ], + # stage3 + [ + [5, 72, 40, 0.25, 2], + ], + [ + [5, 120, 40, 0.25, 1], + ], + # stage4 + [ + [3, 240, 80, 0, 2], + ], + [ + [3, 200, 80, 0, 1], + [3, 184, 80, 0, 1], + [3, 184, 80, 0, 1], + [3, 480, 112, 0.25, 1], + [3, 672, 112, 0.25, 1], + ], + # stage5 + [ + [5, 672, 160, 0.25, 2], + ], + [ + [5, 960, 160, 0, 1], + [5, 960, 160, 0.25, 1], + [5, 960, 160, 0, 1], + [5, 960, 160, 0.25, 1], + ], +] + + +def apply_short_block( + inputs, + output_channels: int, + kernel_size: int = 1, + strides: int = 1, + name="short_block", +): + x = inputs + x = apply_conv2d_block( + x, + output_channels, + kernel_size, + strides, + activation=None, + name=f"{name}_0", + ) + x = apply_conv2d_block( + x, + output_channels, + (1, 5), + 1, + activation=None, + use_depthwise=True, + padding="same", + name=f"{name}_1", + ) + x = apply_conv2d_block( + x, + output_channels, + (5, 1), + 1, + activation=None, + use_depthwise=True, + padding="same", + name=f"{name}_2", + ) + return x + + +def apply_ghost_block_v3( + inputs, + output_channels: int, + expand_ratio: float = 2.0, + kernel_size: int = 1, + depthwise_kernel_size: int = 3, + strides: int = 1, + activation="relu", + mode="ori", + reparameterized: bool = False, + name="ghost_block_v3", +): + assert mode in ("ori", "ori_shortcut_mul_conv15") + + channels_axis = -1 if backend.image_data_format() == "channels_last" else -3 + hidden_channels_1 = int(ops.ceil(output_channels / expand_ratio)) + hidden_channels_2 = int(hidden_channels_1 * (expand_ratio - 1.0)) + input_channels = inputs.shape[channels_axis] + has_skip1 = input_channels == hidden_channels_1 and strides == 1 + has_skip2 = hidden_channels_1 == hidden_channels_2 + has_scale1 = kernel_size > 1 + has_scale2 = depthwise_kernel_size > 1 + + x = inputs + residual = inputs + + x1 = ReparameterizableConv2D( + hidden_channels_1, + kernel_size, + strides, + has_skip=has_skip1, + has_scale=has_scale1, + branch_size=3, + reparameterized=reparameterized, + activation=activation, + name=f"{name}_primary_conv", + )(x) + x2 = ReparameterizableConv2D( + hidden_channels_2, + depthwise_kernel_size, + 1, + has_skip=has_skip2, + has_scale=has_scale2, + use_depthwise=True, + branch_size=3, + reparameterized=reparameterized, + activation=activation, + name=f"{name}_cheap_operation", + )(x1) + out = layers.Concatenate(axis=channels_axis)([x1, x2]) + + if mode == "ori_shortcut_mul_conv15": + if channels_axis == -1: + out = out[..., :output_channels] + h, w = out.shape[-3], out.shape[-2] + else: + out = out[:, :output_channels, ...] + h, w = out.shape[-2], out.shape[-1] + residual = layers.AveragePooling2D(2, 2)(x) + residual = apply_short_block( + residual, + output_channels, + kernel_size, + strides, + name=f"{name}_short_conv", + ) + residual = layers.Activation("sigmoid")(residual) + residual = ops.image.resize( + residual, + size=(h, w), + interpolation="nearest", + data_format=backend.image_data_format(), + ) + out = layers.Multiply()([out, residual]) + + return out + + +def apply_ghost_bottleneck_v3( + inputs, + hidden_channels: int, + output_channels: int, + depthwise_kernel_size: int = 3, + strides: int = 1, + se_ratio: float = 0.0, + activation="relu", + pw_ghost_mode="ori", + reparameterized: bool = False, + name="ghost_bottlenect", +): + channels_axis = -1 if backend.image_data_format() == "channels_last" else -3 + input_channels = inputs.shape[channels_axis] + has_skip = strides == 1 + has_scale = depthwise_kernel_size > 1 + has_se = se_ratio is not None and se_ratio > 0.0 + + x = inputs + shortcut = inputs + + # Point-wise expansion + x = apply_ghost_block_v3( + x, + hidden_channels, + activation=activation, + mode=pw_ghost_mode, + reparameterized=reparameterized, + name=f"{name}_ghost1", + ) + + # Depth-wise + if strides > 1: + x = ReparameterizableConv2D( + hidden_channels, + depthwise_kernel_size, + strides=strides, + has_skip=has_skip, + has_scale=has_scale, + use_depthwise=True, + branch_size=3, + reparameterized=reparameterized, + activation=None, + name=f"{name}_conv_dw", + )(x) + + # Squeeze-and-excitation + if has_se: + x = apply_se_block( + x, + se_ratio, + gate_activation="hard_sigmoid", + make_divisible_number=4, + name=f"{name}_se", + ) + + # Point-wise + x = apply_ghost_block_v3( + x, + output_channels, + activation=None, + mode="ori", + reparameterized=reparameterized, + name=f"{name}_ghost2", + ) + + # Shortcut + if input_channels != output_channels or strides > 1: + shortcut = apply_conv2d_block( + shortcut, + input_channels, + depthwise_kernel_size, + strides, + activation=None, + use_depthwise=True, + name=f"{name}_shortcut1", + ) + shortcut = apply_conv2d_block( + shortcut, + output_channels, + 1, + 1, + activation=None, + padding="valid", + name=f"{name}_shortcut2", + ) + + out = layers.Add(name=name)([x, shortcut]) + return out + + +@keras.saving.register_keras_serializable(package="kimm") +class GhostNetV3(BaseModel): + # Updated weights: use ReparameterizableConv2D + default_origin = "https://github.com/james77777778/keras-image-models/releases/download/0.1.2/" + available_feature_keys = [ + "STEM_S2", + *[ + f"BLOCK{i}_S{j}" + for i, j in zip(range(9), [2, 4, 4, 8, 8, 16, 16, 32, 32]) + ], + ] + + def __init__( + self, + width: float = 1.0, + config: typing.Union[str, typing.List] = "default", + reparameterized: bool = False, + input_tensor=None, + **kwargs, + ): + _available_configs = ["default"] + if config == "default": + _config = DEFAULT_CONFIG + else: + raise ValueError( + f"config must be one of {_available_configs} using string. " + f"Received: config={config}" + ) + + self.set_properties(kwargs) + inputs = self.determine_input_tensor( + input_tensor, + self._input_shape, + self._default_size, + require_flatten=self._include_top, + static_shape=True, + ) + x = inputs + + x = self.build_preprocessing(x, "imagenet") + + # Prepare feature extraction + features = {} + + # Stem + stem_channels = make_divisible(16 * width, 4) + x = apply_conv2d_block( + x, stem_channels, 3, 2, activation="relu", name="conv_stem" + ) + features["STEM_S2"] = x + + # Blocks + total_layer_idx = 0 + current_stride = 2 + for current_block_idx, cfg in enumerate(_config): + for current_layer_idx, (k, e, c, se, s) in enumerate(cfg): + output_channels = make_divisible(c * width, 4) + hidden_channels = make_divisible(e * width, 4) + pw_ghost_mode = ( + "ori" if total_layer_idx <= 1 else "ori_shortcut_mul_conv15" + ) + name = f"blocks_{current_block_idx}_{current_layer_idx}" + x = apply_ghost_bottleneck_v3( + x, + hidden_channels, + output_channels, + k, + s, + se_ratio=se, + pw_ghost_mode=pw_ghost_mode, + reparameterized=reparameterized, + name=name, + ) + total_layer_idx += 1 + current_stride *= s + features[f"BLOCK{current_block_idx}_S{current_stride}"] = x + + # Last block + output_channels = make_divisible(e * width, 4) + x = apply_conv2d_block( + x, + output_channels, + 1, + activation="relu", + name=f"blocks_{current_block_idx+1}", + ) + + # Head + x = self.build_head(x) + + super().__init__(inputs=inputs, outputs=x, features=features, **kwargs) + + # All references to `self` below this line + self.width = width + self.config = config + self.reparameterized = reparameterized + + def build_top(self, inputs, classes, classifier_activation, dropout_rate): + x = layers.GlobalAveragePooling2D(name="avg_pool", keepdims=True)( + inputs + ) + x = layers.Conv2D( + 1280, 1, 1, use_bias=True, activation="relu", name="conv_head" + )(x) + x = layers.Flatten()(x) + x = layers.Dropout(rate=dropout_rate, name="conv_head_dropout")(x) + x = layers.Dense( + classes, activation=classifier_activation, name="classifier" + )(x) + return x + + def get_config(self): + config = super().get_config() + config.update( + { + "width": self.width, + "config": self.config, + "reparameterized": self.reparameterized, + } + ) + return config + + def fix_config(self, config): + unused_kwargs = ["width", "config"] + for k in unused_kwargs: + config.pop(k, None) + return config + + def get_reparameterized_model(self): + config = self.get_config() + config["reparameterized"] = True + config["weights"] = None + model = GhostNetV3(**config) + for layer, rep_layer in zip(self.layers, model.layers): + if hasattr(layer, "get_reparameterized_weights"): + kernel, bias = layer.get_reparameterized_weights() + rep_layer.reparameterized_conv2d.kernel.assign(kernel) + rep_layer.reparameterized_conv2d.bias.assign(bias) + else: + for weight, target_weight in zip( + layer.weights, rep_layer.weights + ): + target_weight.assign(weight) + return model + + +# Model Definition + + +class GhostNetV3Variant(GhostNetV3): + # Parameters + width = None + config = None + + def __init__( + self, + reparameterized: bool = False, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, # Defaults to 0.0 + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = "imagenet", + name: typing.Optional[str] = None, + **kwargs, + ): + if type(self) is GhostNetV3Variant: + raise NotImplementedError( + f"Cannot instantiate base class: {self.__class__.__name__}. " + "You should use its subclasses." + ) + kwargs = self.fix_config(kwargs) + if len(getattr(self, "available_weights", [])) == 0: + warnings.warn( + f"{self.__class__.__name__} doesn't have pretrained weights " + f"for '{weights}'." + ) + weights = None + super().__init__( + width=self.width, + config=self.config, + reparameterized=reparameterized, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, + name=name or str(self.__class__.__name__), + **kwargs, + ) + + +@kimm_export(parent_path=["kimm.models", "kimm.models.ghostnet"]) +class GhostNetV3W050(GhostNetV3Variant): + available_weights = [] + + # Parameters + width = 0.5 + config = "default" + + +@kimm_export(parent_path=["kimm.models", "kimm.models.ghostnet"]) +class GhostNetV3W100(GhostNetV3Variant): + available_weights = [ + ( + "imagenet", + GhostNetV3.default_origin, + "ghostnetv3w100_ghostnetv3-1.0.keras", + ) + ] + + # Parameters + width = 1.0 + config = "default" + + +@kimm_export(parent_path=["kimm.models", "kimm.models.ghostnet"]) +class GhostNetV3W130(GhostNetV3Variant): + available_weights = [] + + # Parameters + width = 1.3 + config = "default" + + +@kimm_export(parent_path=["kimm.models", "kimm.models.ghostnet"]) +class GhostNetV3W160(GhostNetV3Variant): + available_weights = [] + + # Parameters + width = 1.6 + config = "default" + + +add_model_to_registry(GhostNetV3W050) +add_model_to_registry(GhostNetV3W100, "imagenet") +add_model_to_registry(GhostNetV3W130) +add_model_to_registry(GhostNetV3W160) diff --git a/kimm/_src/models/mobileone.py b/kimm/_src/models/mobileone.py index 753da35..797d1ca 100644 --- a/kimm/_src/models/mobileone.py +++ b/kimm/_src/models/mobileone.py @@ -4,13 +4,15 @@ from keras import backend from kimm._src.kimm_export import kimm_export -from kimm._src.layers.mobile_one_conv2d import MobileOneConv2D +from kimm._src.layers.reparameterizable_conv2d import ReparameterizableConv2D from kimm._src.models.base_model import BaseModel from kimm._src.utils.model_registry import add_model_to_registry @keras.saving.register_keras_serializable(package="kimm") class MobileOne(BaseModel): + # Updated weights: use ReparameterizableConv2D + default_origin = "https://github.com/james77777778/keras-image-models/releases/download/0.1.2/" available_feature_keys = [ "STEM_S2", *[f"BLOCK{i}_S{j}" for i, j in zip(range(4), [4, 8, 16, 32])], @@ -53,11 +55,12 @@ def __init__( features = {} # stem - x = MobileOneConv2D( + x = ReparameterizableConv2D( stem_channels, 3, 2, has_skip=False, + branch_size=1, reparameterized=reparameterized, activation="relu", name="stem", @@ -81,7 +84,7 @@ def __init__( name1 = f"stages_{current_stage_idx}_{current_block_idx}" name2 = f"stages_{current_stage_idx}_{current_block_idx+1}" # Depthwise - x = MobileOneConv2D( + x = ReparameterizableConv2D( input_channels, 3, strides, @@ -93,11 +96,12 @@ def __init__( name=name1, )(x) # Pointwise - x = MobileOneConv2D( + x = ReparameterizableConv2D( c, 1, 1, has_skip=has_skip2, + has_scale=False, use_depthwise=False, branch_size=branch_size, reparameterized=reparameterized, @@ -150,14 +154,14 @@ def get_reparameterized_model(self): config["reparameterized"] = True config["weights"] = None model = MobileOne(**config) - for layer, reparameterized_layer in zip(self.layers, model.layers): + for layer, rep_layer in zip(self.layers, model.layers): if hasattr(layer, "get_reparameterized_weights"): kernel, bias = layer.get_reparameterized_weights() - reparameterized_layer.rep_conv2d.kernel.assign(kernel) - reparameterized_layer.rep_conv2d.bias.assign(bias) + rep_layer.reparameterized_conv2d.kernel.assign(kernel) + rep_layer.reparameterized_conv2d.bias.assign(bias) else: for weight, target_weight in zip( - layer.weights, reparameterized_layer.weights + layer.weights, rep_layer.weights ): target_weight.assign(weight) return model diff --git a/kimm/_src/models/models_test.py b/kimm/_src/models/models_test.py index 8240d89..eab679d 100644 --- a/kimm/_src/models/models_test.py +++ b/kimm/_src/models/models_test.py @@ -247,8 +247,8 @@ def test_weights_invalid_string(self): ), # ghostnet ( - kimm_models.ghostnet.GhostNet100.__name__, - kimm_models.ghostnet.GhostNet100, + kimm_models.ghostnet.GhostNetW100.__name__, + kimm_models.ghostnet.GhostNetV2W100, 224, [ ("STEM_S2", [1, 112, 112, 16]), @@ -259,8 +259,20 @@ def test_weights_invalid_string(self): ], ), ( - kimm_models.ghostnet.GhostNet100V2.__name__, - kimm_models.ghostnet.GhostNet100V2, + kimm_models.ghostnet.GhostNetV2W100.__name__, + kimm_models.ghostnet.GhostNetV2W100, + 224, + [ + ("STEM_S2", [1, 112, 112, 16]), + ("BLOCK1_S4", [1, 56, 56, 24]), + ("BLOCK3_S8", [1, 28, 28, 40]), + ("BLOCK5_S16", [1, 14, 14, 80]), + ("BLOCK7_S32", [1, 7, 7, 160]), + ], + ), + ( + kimm_models.ghostnet_v3.GhostNetV3W100.__name__, + kimm_models.ghostnet_v3.GhostNetV3W100, 224, [ ("STEM_S2", [1, 112, 112, 16]), @@ -598,6 +610,11 @@ def test_predict( self.assertEqual(list(y.shape), [1, 1000]) @parameterized.named_parameters( + ( + kimm_models.ghostnet_v3.GhostNetV3W050.__name__, + kimm_models.ghostnet_v3.GhostNetV3W050, + 224, + ), ( kimm_models.repvgg.RepVGGA0.__name__, kimm_models.repvgg.RepVGGA0, diff --git a/kimm/_src/models/repvgg.py b/kimm/_src/models/repvgg.py index 0ee2731..036210c 100644 --- a/kimm/_src/models/repvgg.py +++ b/kimm/_src/models/repvgg.py @@ -4,13 +4,15 @@ from keras import backend from kimm._src.kimm_export import kimm_export -from kimm._src.layers.rep_conv2d import RepConv2D +from kimm._src.layers.reparameterizable_conv2d import ReparameterizableConv2D from kimm._src.models.base_model import BaseModel from kimm._src.utils.model_registry import add_model_to_registry @keras.saving.register_keras_serializable(package="kimm") class RepVGG(BaseModel): + # Updated weights: use ReparameterizableConv2D + default_origin = "https://github.com/james77777778/keras-image-models/releases/download/0.1.2/" available_feature_keys = [ "STEM_S2", *[f"BLOCK{i}_S{j}" for i, j in zip(range(4), [4, 8, 16, 32])], @@ -53,11 +55,12 @@ def __init__( features = {} # stem - x = RepConv2D( + x = ReparameterizableConv2D( stem_channels, 3, 2, has_skip=False, + branch_size=1, reparameterized=reparameterized, activation="relu", name="stem", @@ -77,11 +80,12 @@ def __init__( input_channels = x.shape[channels_axis] has_skip = input_channels == c and strides == 1 name = f"stages_{current_stage_idx}_{current_block_idx}" - x = RepConv2D( + x = ReparameterizableConv2D( c, 3, strides, has_skip=has_skip, + branch_size=1, reparameterized=reparameterized, activation="relu", name=name, @@ -128,14 +132,14 @@ def get_reparameterized_model(self): config["reparameterized"] = True config["weights"] = None model = RepVGG(**config) - for layer, reparameterized_layer in zip(self.layers, model.layers): + for layer, rep_layer in zip(self.layers, model.layers): if hasattr(layer, "get_reparameterized_weights"): kernel, bias = layer.get_reparameterized_weights() - reparameterized_layer.rep_conv2d.kernel.assign(kernel) - reparameterized_layer.rep_conv2d.bias.assign(bias) + rep_layer.reparameterized_conv2d.kernel.assign(kernel) + rep_layer.reparameterized_conv2d.bias.assign(bias) else: for weight, target_weight in zip( - layer.weights, reparameterized_layer.weights + layer.weights, rep_layer.weights ): target_weight.assign(weight) return model @@ -287,7 +291,7 @@ class RepVGGB2(RepVGGVariant): @kimm_export(parent_path=["kimm.models", "kimm.models.repvgg"]) -class RepVGGB3(RepVGG): +class RepVGGB3(RepVGGVariant): available_weights = [ ( "imagenet", diff --git a/kimm/_src/utils/model_utils.py b/kimm/_src/utils/model_utils.py index f1a370a..32b9831 100644 --- a/kimm/_src/utils/model_utils.py +++ b/kimm/_src/utils/model_utils.py @@ -17,16 +17,12 @@ def get_reparameterized_model(model: BaseModel): config["reparameterized"] = True config["weights"] = None reparameterized_model = type(model).from_config(config) - for layer, reparameterized_layer in zip( - model.layers, reparameterized_model.layers - ): + for layer, rep_layer in zip(model.layers, reparameterized_model.layers): if hasattr(layer, "get_reparameterized_weights"): kernel, bias = layer.get_reparameterized_weights() - reparameterized_layer.rep_conv2d.kernel.assign(kernel) - reparameterized_layer.rep_conv2d.bias.assign(bias) + rep_layer.reparameterized_conv2d.kernel.assign(kernel) + rep_layer.reparameterized_conv2d.bias.assign(bias) else: - for weight, target_weight in zip( - layer.weights, reparameterized_layer.weights - ): + for weight, target_weight in zip(layer.weights, rep_layer.weights): target_weight.assign(weight) return reparameterized_model diff --git a/kimm/_src/utils/timm_utils.py b/kimm/_src/utils/timm_utils.py index c5d80c8..427d162 100644 --- a/kimm/_src/utils/timm_utils.py +++ b/kimm/_src/utils/timm_utils.py @@ -47,8 +47,8 @@ def separate_keras_weights(keras_model: keras.Model): trainable_weights = [] non_trainable_weights = [] for layer in keras_model.layers: - if hasattr(layer, "extra_layers"): - for sub_layer in layer.extra_layers: + if hasattr(layer, "_sublayers"): + for sub_layer in layer._sublayers: sub_layer: keras.Layer for weight in sub_layer.trainable_weights: trainable_weights.append( diff --git a/kimm/_src/version.py b/kimm/_src/version.py index b38714d..837cfd4 100644 --- a/kimm/_src/version.py +++ b/kimm/_src/version.py @@ -1,6 +1,6 @@ from kimm._src.kimm_export import kimm_export -__version__ = "0.2.1" +__version__ = "0.2.2" @kimm_export("kimm") diff --git a/kimm/layers/__init__.py b/kimm/layers/__init__.py index bdbe229..08822c4 100644 --- a/kimm/layers/__init__.py +++ b/kimm/layers/__init__.py @@ -7,6 +7,5 @@ from kimm._src.layers.attention import Attention from kimm._src.layers.layer_scale import LayerScale from kimm._src.layers.learnable_affine import LearnableAffine -from kimm._src.layers.mobile_one_conv2d import MobileOneConv2D from kimm._src.layers.position_embedding import PositionEmbedding -from kimm._src.layers.rep_conv2d import RepConv2D +from kimm._src.layers.reparameterizable_conv2d import ReparameterizableConv2D diff --git a/kimm/models/__init__.py b/kimm/models/__init__.py index 688f0cf..688e93f 100644 --- a/kimm/models/__init__.py +++ b/kimm/models/__init__.py @@ -47,12 +47,16 @@ from kimm._src.models.efficientnet import TinyNetC from kimm._src.models.efficientnet import TinyNetD from kimm._src.models.efficientnet import TinyNetE -from kimm._src.models.ghostnet import GhostNet050 -from kimm._src.models.ghostnet import GhostNet100 -from kimm._src.models.ghostnet import GhostNet100V2 -from kimm._src.models.ghostnet import GhostNet130 -from kimm._src.models.ghostnet import GhostNet130V2 -from kimm._src.models.ghostnet import GhostNet160V2 +from kimm._src.models.ghostnet import GhostNetV2W100 +from kimm._src.models.ghostnet import GhostNetV2W130 +from kimm._src.models.ghostnet import GhostNetV2W160 +from kimm._src.models.ghostnet import GhostNetW050 +from kimm._src.models.ghostnet import GhostNetW100 +from kimm._src.models.ghostnet import GhostNetW130 +from kimm._src.models.ghostnet_v3 import GhostNetV3W050 +from kimm._src.models.ghostnet_v3 import GhostNetV3W100 +from kimm._src.models.ghostnet_v3 import GhostNetV3W130 +from kimm._src.models.ghostnet_v3 import GhostNetV3W160 from kimm._src.models.hgnet import HGNetBase from kimm._src.models.hgnet import HGNetSmall from kimm._src.models.hgnet import HGNetTiny diff --git a/kimm/models/ghostnet/__init__.py b/kimm/models/ghostnet/__init__.py index d1ae684..6aa5e36 100644 --- a/kimm/models/ghostnet/__init__.py +++ b/kimm/models/ghostnet/__init__.py @@ -4,9 +4,13 @@ since your modifications would be overwritten. """ -from kimm._src.models.ghostnet import GhostNet050 -from kimm._src.models.ghostnet import GhostNet100 -from kimm._src.models.ghostnet import GhostNet100V2 -from kimm._src.models.ghostnet import GhostNet130 -from kimm._src.models.ghostnet import GhostNet130V2 -from kimm._src.models.ghostnet import GhostNet160V2 +from kimm._src.models.ghostnet import GhostNetV2W100 +from kimm._src.models.ghostnet import GhostNetV2W130 +from kimm._src.models.ghostnet import GhostNetV2W160 +from kimm._src.models.ghostnet import GhostNetW050 +from kimm._src.models.ghostnet import GhostNetW100 +from kimm._src.models.ghostnet import GhostNetW130 +from kimm._src.models.ghostnet_v3 import GhostNetV3W050 +from kimm._src.models.ghostnet_v3 import GhostNetV3W100 +from kimm._src.models.ghostnet_v3 import GhostNetV3W130 +from kimm._src.models.ghostnet_v3 import GhostNetV3W160 diff --git a/shell/export_models.sh b/shell/export_models.sh index a856728..9d0c3e5 100755 --- a/shell/export_models.sh +++ b/shell/export_models.sh @@ -9,6 +9,7 @@ python3 -m tools.convert_convnext_from_timm python3 -m tools.convert_densenet_from_timm python3 -m tools.convert_efficientnet_from_timm python3 -m tools.convert_ghostnet_from_timm +python3 -m tools.convert_ghostnet_v3_from_github python3 -m tools.convert_hgnet_from_timm python3 -m tools.convert_inception_next_from_timm python3 -m tools.convert_inception_v3_from_timm diff --git a/tools/README.md b/tools/README.md index cf0bce3..44d942f 100644 --- a/tools/README.md +++ b/tools/README.md @@ -18,8 +18,9 @@ Setup `gh` Upload the converted file ```bash -gh release upload ... [flags] +# --clobber means overwrite the existing file +gh release upload ... --clobber # For example: -gh release upload 0.1.0 exported/visiontransformertiny16_vit_tiny_patch16_384.keras +gh release upload 0.1.0 exported/* --clobber ``` diff --git a/tools/convert_ghostnet_v3_from_github.py b/tools/convert_ghostnet_v3_from_github.py new file mode 100644 index 0000000..29cd082 --- /dev/null +++ b/tools/convert_ghostnet_v3_from_github.py @@ -0,0 +1,307 @@ +""" +From: https://github.com/huawei-noah/Efficient-AI-Backbones +""" + +import os +import pathlib +import urllib.parse + +import keras +import numpy as np +import torch + +from kimm.models import ghostnet +from kimm.timm_utils import assign_weights +from kimm.timm_utils import is_same_weights +from kimm.timm_utils import separate_keras_weights +from kimm.timm_utils import separate_torch_state_dict +from tools.third_party.ghostnet_v3.ghostnetv3 import ghostnetv3 + +github_model_items = [ + ( + "ghostnetv3-1.0", + ghostnetv3, + dict(width=1.0), + "https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/GhostNetV3/ghostnetv3-1.0.pth.tar", + ), +] +keras_model_classes = [ + ghostnet.GhostNetV3W100, +] + +for github_model_item, keras_model_class in zip( + github_model_items, keras_model_classes +): + """ + Prepare timm model and keras model + """ + model_name, model_class, model_args, model_url = github_model_item + + input_shape = [224, 224, 3] + result = urllib.parse.urlparse(model_url) + filename = pathlib.Path(result.path).name + file_path = keras.utils.get_file( + fname=filename, origin=model_url, cache_subdir="kimm_models" + ) + state_dict = torch.load(file_path, map_location="cpu")["state_dict"] + torch_model = model_class(**model_args) + torch_model.load_state_dict(state_dict) + torch_model.eval() + trainable_state_dict, non_trainable_state_dict = separate_torch_state_dict( + state_dict + ) + keras_model = keras_model_class( + input_shape=input_shape, + include_preprocessing=False, + classifier_activation="linear", + weights=None, + ) + trainable_weights, non_trainable_weights = separate_keras_weights( + keras_model + ) + + # for torch_name, (_, keras_name) in zip( + # trainable_state_dict.keys(), trainable_weights + # ): + # print(f"{torch_name} {keras_name}") + + # print(len(trainable_state_dict.keys())) + # print(len(trainable_weights)) + # exit() + + # for torch_name, (_, keras_name) in zip( + # non_trainable_state_dict.keys(), non_trainable_weights + # ): + # print(f"{torch_name} {keras_name}") + + # print(len(non_trainable_state_dict.keys())) + # print(len(non_trainable_weights)) + # exit() + + """ + Assign weights + """ + for keras_weight, keras_name in trainable_weights + non_trainable_weights: + keras_name: str + torch_name = keras_name + torch_name = torch_name.replace("_", ".") + # skip reparam_conv + if "reparam_conv_conv2d" in keras_name: + continue + # Stem + torch_name = torch_name.replace("conv.stem.conv2d", "conv_stem") + torch_name = torch_name.replace("conv.stem.bn", "bn1") + # ReparameterizableConv2D + # primary + for i in range(3): + torch_name = torch_name.replace( + f"primary.conv.conv.kxk.{i}.kernel", + f"primary_rpr_conv.{i}.conv.weight", + ) + for pair in ( + ("gamma", "weight"), + ("beta", "bias"), + ("moving.mean", "running_mean"), + ("moving.variance", "running_var"), + ): + a, b = pair + torch_name = torch_name.replace( + f"primary.conv.conv.kxk.{i}.{a}", + f"primary_rpr_conv.{i}.bn.{b}", + ) + # cheap + torch_name = torch_name.replace( + "cheap.operation.skip", "cheap_rpr_skip" + ) + torch_name = torch_name.replace( + "cheap.operation.conv.scale.kernel", "cheap_rpr_scale.conv.weight" + ) + for pair in ( + ("gamma", "weight"), + ("beta", "bias"), + ("moving.mean", "running_mean"), + ("moving.variance", "running_var"), + ): + a, b = pair + torch_name = torch_name.replace( + f"cheap.operation.conv.scale.{a}", + f"cheap_rpr_scale.bn.{b}", + ) + for i in range(3): + torch_name = torch_name.replace( + f"cheap.operation.conv.kxk.{i}.kernel", + f"cheap_rpr_conv.{i}.conv.weight", + ) + for pair in ( + ("gamma", "weight"), + ("beta", "bias"), + ("moving.mean", "running_mean"), + ("moving.variance", "running_var"), + ): + a, b = pair + torch_name = torch_name.replace( + f"cheap.operation.conv.kxk.{i}.{a}", + f"cheap_rpr_conv.{i}.bn.{b}", + ) + # short + for i in range(3): + torch_name = torch_name.replace( + "short.conv.0.conv2d.kernel", "short_conv.0.weight" + ) + for pair in ( + ("gamma", "weight"), + ("beta", "bias"), + ("moving.mean", "running_mean"), + ("moving.variance", "running_var"), + ): + a, b = pair + torch_name = torch_name.replace( + f"short.conv.0.bn.{a}", + f"short_conv.1.{b}", + ) + torch_name = torch_name.replace( + "short.conv.1.dwconv2d.kernel", "short_conv.2.weight" + ) + for pair in ( + ("gamma", "weight"), + ("beta", "bias"), + ("moving.mean", "running_mean"), + ("moving.variance", "running_var"), + ): + a, b = pair + torch_name = torch_name.replace( + f"short.conv.1.bn.{a}", + f"short_conv.3.{b}", + ) + torch_name = torch_name.replace( + "short.conv.2.dwconv2d.kernel", "short_conv.4.weight" + ) + for pair in ( + ("gamma", "weight"), + ("beta", "bias"), + ("moving.mean", "running_mean"), + ("moving.variance", "running_var"), + ): + a, b = pair + torch_name = torch_name.replace( + f"short.conv.2.bn.{a}", f"short_conv.5.{b}" + ) + # Depth-wise + torch_name = torch_name.replace( + "conv.dw.conv.scale.kernel", "dw_rpr_scale.conv.weight" + ) + for pair in ( + ("gamma", "weight"), + ("beta", "bias"), + ("moving.mean", "running_mean"), + ("moving.variance", "running_var"), + ): + a, b = pair + torch_name = torch_name.replace( + f"conv.dw.conv.scale.{a}", + f"dw_rpr_scale.bn.{b}", + ) + for i in range(3): + torch_name = torch_name.replace( + f"conv.dw.conv.kxk.{i}.kernel", + f"dw_rpr_conv.{i}.conv.weight", + ) + for pair in ( + ("gamma", "weight"), + ("beta", "bias"), + ("moving.mean", "running_mean"), + ("moving.variance", "running_var"), + ): + a, b = pair + torch_name = torch_name.replace( + f"conv.dw.conv.kxk.{i}.{a}", + f"dw_rpr_conv.{i}.bn.{b}", + ) + # Squeeze-and-excitation + torch_name = torch_name.replace("se.conv.reduce", "se.conv_reduce") + torch_name = torch_name.replace("se.conv.expand", "se.conv_expand") + # Shortcut + torch_name = torch_name.replace( + "shortcut1.dwconv2d.kernel", "shortcut.0.weight" + ) + for pair in ( + ("gamma", "weight"), + ("beta", "bias"), + ("moving.mean", "running_mean"), + ("moving.variance", "running_var"), + ): + a, b = pair + torch_name = torch_name.replace( + f"shortcut1.bn.{a}", f"shortcut.1.{b}" + ) + torch_name = torch_name.replace( + "shortcut2.conv2d.kernel", "shortcut.2.weight" + ) + for pair in ( + ("gamma", "weight"), + ("beta", "bias"), + ("moving.mean", "running_mean"), + ("moving.variance", "running_var"), + ): + a, b = pair + torch_name = torch_name.replace( + f"shortcut2.bn.{a}", f"shortcut.3.{b}" + ) + + # Last block + torch_name = torch_name.replace("blocks.9.conv2d", "blocks.9.0.conv") + torch_name = torch_name.replace("blocks.9.bn", "blocks.9.0.bn1") + + # Head + torch_name = torch_name.replace("conv.head", "conv_head") + + # weights naming mapping + torch_name = torch_name.replace("kernel", "weight") # conv2d + torch_name = torch_name.replace("gamma", "weight") # bn + torch_name = torch_name.replace("beta", "bias") # bn + torch_name = torch_name.replace("moving.mean", "running_mean") # bn + torch_name = torch_name.replace("moving.variance", "running_var") # bn + + # assign weights + if torch_name in trainable_state_dict: + torch_weights = trainable_state_dict[torch_name].numpy() + elif torch_name in non_trainable_state_dict: + torch_weights = non_trainable_state_dict[torch_name].numpy() + else: + raise ValueError( + "Can't find the corresponding torch weights. " + f"Got keras_name={keras_name}, torch_name={torch_name}" + ) + if is_same_weights(keras_name, keras_weight, torch_name, torch_weights): + assign_weights(keras_name, keras_weight, torch_weights) + else: + raise ValueError( + "Can't find the corresponding torch weights. The shape is " + f"mismatched. Got keras_name={keras_name}, " + f"keras_weight shape={keras_weight.shape}, " + f"torch_name={torch_name}, " + f"torch_weights shape={torch_weights.shape}" + ) + + """ + Verify model outputs + """ + np.random.seed(2023) + keras_data = np.random.uniform(size=[1] + input_shape).astype("float32") + torch_data = torch.from_numpy(np.transpose(keras_data, [0, 3, 1, 2])) + torch_y = torch_model(torch_data) + keras_y = keras_model(keras_data, training=False) + torch_y = torch_y.detach().cpu().numpy() + torch_y = np.expand_dims(torch_y, axis=0) + keras_y = keras.ops.convert_to_numpy(keras_y) + # TODO: Error is large + np.testing.assert_allclose(torch_y, keras_y, atol=0.5) + print(f"{keras_model_class.__name__}: output matched!") + + """ + Save converted model + """ + os.makedirs("exported", exist_ok=True) + export_path = f"exported/{keras_model.name.lower()}_{model_name}.keras" + keras_model.save(export_path) + print(f"Export to {export_path}") diff --git a/tools/convert_mobileone_from_timm.py b/tools/convert_mobileone_from_timm.py index 24d9291..24f9e58 100644 --- a/tools/convert_mobileone_from_timm.py +++ b/tools/convert_mobileone_from_timm.py @@ -82,6 +82,15 @@ if "reparam_conv_conv2d" in keras_name: continue # mobile_one_conv2d + torch_name = torch_name.replace("skip.gamma", "identity.gamma") + torch_name = torch_name.replace("skip.beta", "identity.beta") + torch_name = torch_name.replace( + "conv.scale.kernel", "conv_scale.conv.kernel" + ) + torch_name = torch_name.replace( + "conv.scale.gamma", "conv_scale.bn.gamma" + ) + torch_name = torch_name.replace("conv.scale.beta", "conv_scale.bn.beta") if "conv.kxk" in torch_name and "kernel" in torch_name: torch_name = torch_name.replace("conv.kxk", "conv_kxk") torch_name = torch_name.replace("kernel", "conv.kernel") @@ -91,14 +100,19 @@ if "conv.kxk" in torch_name and "beta" in torch_name: torch_name = torch_name.replace("conv.kxk", "conv_kxk") torch_name = torch_name.replace("beta", "bn.beta") + # mobile_one_conv2d bn torch_name = torch_name.replace( - "conv.scale.kernel", "conv_scale.conv.kernel" + "skip.moving.mean", "identity.moving.mean" ) torch_name = torch_name.replace( - "conv.scale.gamma", "conv_scale.bn.gamma" + "skip.moving.variance", "identity.moving.variance" + ) + torch_name = torch_name.replace( + "conv.scale.moving.mean", "conv_scale.bn.moving.mean" + ) + torch_name = torch_name.replace( + "conv.scale.moving.variance", "conv_scale.bn.moving.variance" ) - torch_name = torch_name.replace("conv.scale.beta", "conv_scale.bn.beta") - # mobile_one_conv2d bn if "conv.kxk" in torch_name and "moving.mean" in torch_name: torch_name = torch_name.replace("conv.kxk", "conv_kxk") torch_name = torch_name.replace("moving.mean", "bn.moving.mean") @@ -107,12 +121,6 @@ torch_name = torch_name.replace( "moving.variance", "bn.moving.variance" ) - torch_name = torch_name.replace( - "conv.scale.moving.mean", "conv_scale.bn.moving.mean" - ) - torch_name = torch_name.replace( - "conv.scale.moving.variance", "conv_scale.bn.moving.variance" - ) # head torch_name = torch_name.replace("classifier", "head.fc") diff --git a/tools/convert_repvgg_from_timm.py b/tools/convert_repvgg_from_timm.py index 8454f8b..12c2452 100644 --- a/tools/convert_repvgg_from_timm.py +++ b/tools/convert_repvgg_from_timm.py @@ -86,28 +86,36 @@ if "reparam_conv_conv2d" in keras_name: continue # repconv2d + torch_name = torch_name.replace("skip.gamma", "identity.gamma") + torch_name = torch_name.replace("skip.beta", "identity.beta") torch_name = torch_name.replace( - "conv.kxk.kernel", "conv_kxk.conv.kernel" + "conv.scale.kernel", "conv_1x1.conv.kernel" ) - torch_name = torch_name.replace("conv.kxk.gamma", "conv_kxk.bn.gamma") - torch_name = torch_name.replace("conv.kxk.beta", "conv_kxk.bn.beta") + torch_name = torch_name.replace("conv.scale.gamma", "conv_1x1.bn.gamma") + torch_name = torch_name.replace("conv.scale.beta", "conv_1x1.bn.beta") torch_name = torch_name.replace( - "conv.1x1.kernel", "conv_1x1.conv.kernel" + "conv.kxk.0.kernel", "conv_kxk.conv.kernel" ) - torch_name = torch_name.replace("conv.1x1.gamma", "conv_1x1.bn.gamma") - torch_name = torch_name.replace("conv.1x1.beta", "conv_1x1.bn.beta") + torch_name = torch_name.replace("conv.kxk.0.gamma", "conv_kxk.bn.gamma") + torch_name = torch_name.replace("conv.kxk.0.beta", "conv_kxk.bn.beta") # repconv2d bn torch_name = torch_name.replace( - "conv.kxk.moving.mean", "conv_kxk.bn.moving.mean" + "skip.moving.mean", "identity.moving.mean" ) torch_name = torch_name.replace( - "conv.kxk.moving.variance", "conv_kxk.bn.moving.variance" + "skip.moving.variance", "identity.moving.variance" ) torch_name = torch_name.replace( - "conv.1x1.moving.mean", "conv_1x1.bn.moving.mean" + "conv.scale.moving.mean", "conv_1x1.bn.moving.mean" ) torch_name = torch_name.replace( - "conv.1x1.moving.variance", "conv_1x1.bn.moving.variance" + "conv.scale.moving.variance", "conv_1x1.bn.moving.variance" + ) + torch_name = torch_name.replace( + "conv.kxk.0.moving.mean", "conv_kxk.bn.moving.mean" + ) + torch_name = torch_name.replace( + "conv.kxk.0.moving.variance", "conv_kxk.bn.moving.variance" ) # head torch_name = torch_name.replace("classifier", "head.fc") diff --git a/tools/third_party/ghostnet_v3/ghostnetv3.py b/tools/third_party/ghostnet_v3/ghostnetv3.py new file mode 100644 index 0000000..94eb376 --- /dev/null +++ b/tools/third_party/ghostnet_v3/ghostnetv3.py @@ -0,0 +1,1061 @@ +import math +from typing import Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def _make_divisible(v, divisor, min_value=None): + """ + This function is taken from the original tf repo. + It ensures that all layers have a channel number that is divisible by 8 + It can be seen here: + https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py + """ + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. + if new_v < 0.9 * v: + new_v += divisor + return new_v + + +def hard_sigmoid(x, inplace: bool = False): + if inplace: + return x.add_(3.0).clamp_(0.0, 6.0).div_(6.0) + else: + return F.relu6(x + 3.0) / 6.0 + + +class SqueezeExcite(nn.Module): + def __init__( + self, + in_chs, + se_ratio=0.25, + reduced_base_chs=None, + act_layer=nn.ReLU, + gate_fn=hard_sigmoid, + divisor=4, + **_, + ): + super(SqueezeExcite, self).__init__() + self.gate_fn = gate_fn + reduced_chs = _make_divisible( + (reduced_base_chs or in_chs) * se_ratio, divisor + ) + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True) + self.act1 = act_layer(inplace=True) + self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True) + + def forward(self, x): + x_se = self.avg_pool(x) + x_se = self.conv_reduce(x_se) + x_se = self.act1(x_se) + x_se = self.conv_expand(x_se) + x = x * self.gate_fn(x_se) + return x + + +class ConvBnAct(nn.Module): + def __init__( + self, in_chs, out_chs, kernel_size, stride=1, act_layer=nn.ReLU + ): + super(ConvBnAct, self).__init__() + self.conv = nn.Conv2d( + in_chs, out_chs, kernel_size, stride, kernel_size // 2, bias=False + ) + self.bn1 = nn.BatchNorm2d(out_chs) + self.act1 = act_layer(inplace=True) + + def forward(self, x): + x = self.conv(x) + x = self.bn1(x) + x = self.act1(x) + return x + + +def gcd(a, b): + if a < b: + a, b = b, a + while a % b != 0: + c = a % b + a = b + b = c + return b + + +def MyNorm(dim): + return nn.GroupNorm(1, dim) + + +class GhostModule(nn.Module): + def __init__( + self, + inp, + oup, + kernel_size=1, + ratio=2, + dw_size=3, + stride=1, + relu=True, + mode=None, + args=None, + ): + super(GhostModule, self).__init__() + # self.args=args + self.mode = mode + self.gate_loc = "before" + + self.inter_mode = "nearest" + self.scale = 1.0 + + self.infer_mode = False + self.num_conv_branches = 3 + self.dconv_scale = True + self.gate_fn = nn.Sigmoid() + + # if args.gate_fn=='hard_sigmoid': + # self.gate_fn=hard_sigmoid + # elif args.gate_fn=='sigmoid': + # self.gate_fn=nn.Sigmoid() + # elif args.gate_fn=='relu': + # self.gate_fn=nn.ReLU() + # elif args.gate_fn=='clip': + # self.gate_fn=myclip + # elif args.gate_fn=='tanh': + # self.gate_fn=nn.Tanh() + + if self.mode in ["ori"]: + self.oup = oup + init_channels = math.ceil(oup / ratio) + new_channels = init_channels * (ratio - 1) + if self.infer_mode: + self.primary_conv = nn.Sequential( + nn.Conv2d( + inp, + init_channels, + kernel_size, + stride, + kernel_size // 2, + bias=False, + ), + nn.BatchNorm2d(init_channels), + nn.ReLU(inplace=True) if relu else nn.Sequential(), + ) + self.cheap_operation = nn.Sequential( + nn.Conv2d( + init_channels, + new_channels, + dw_size, + 1, + dw_size // 2, + groups=init_channels, + bias=False, + ), + nn.BatchNorm2d(new_channels), + nn.ReLU(inplace=True) if relu else nn.Sequential(), + ) + else: + self.primary_rpr_skip = ( + nn.BatchNorm2d(inp) + if inp == init_channels and stride == 1 + else None + ) + primary_rpr_conv = list() + for _ in range(self.num_conv_branches): + primary_rpr_conv.append( + self._conv_bn( + inp, + init_channels, + kernel_size, + stride, + kernel_size // 2, + bias=False, + ) + ) + self.primary_rpr_conv = nn.ModuleList(primary_rpr_conv) + # Re-parameterizable scale branch + self.primary_rpr_scale = None + if kernel_size > 1: + self.primary_rpr_scale = self._conv_bn( + inp, init_channels, 1, 1, 0, bias=False + ) + self.primary_activation = ( + nn.ReLU(inplace=True) if relu else None + ) + + self.cheap_rpr_skip = ( + nn.BatchNorm2d(init_channels) + if init_channels == new_channels + else None + ) + cheap_rpr_conv = list() + for _ in range(self.num_conv_branches): + cheap_rpr_conv.append( + self._conv_bn( + init_channels, + new_channels, + dw_size, + 1, + dw_size // 2, + groups=init_channels, + bias=False, + ) + ) + self.cheap_rpr_conv = nn.ModuleList(cheap_rpr_conv) + # Re-parameterizable scale branch + self.cheap_rpr_scale = None + if dw_size > 1: + self.cheap_rpr_scale = self._conv_bn( + init_channels, + new_channels, + 1, + 1, + 0, + groups=init_channels, + bias=False, + ) + self.cheap_activation = nn.ReLU(inplace=True) if relu else None + self.in_channels = init_channels + self.groups = init_channels + self.kernel_size = dw_size + + elif self.mode in ["ori_shortcut_mul_conv15"]: + self.oup = oup + init_channels = math.ceil(oup / ratio) + new_channels = init_channels * (ratio - 1) + self.short_conv = nn.Sequential( + nn.Conv2d( + inp, oup, kernel_size, stride, kernel_size // 2, bias=False + ), + nn.BatchNorm2d(oup), + nn.Conv2d( + oup, + oup, + kernel_size=(1, 5), + stride=1, + padding=(0, 2), + groups=oup, + bias=False, + ), + nn.BatchNorm2d(oup), + nn.Conv2d( + oup, + oup, + kernel_size=(5, 1), + stride=1, + padding=(2, 0), + groups=oup, + bias=False, + ), + nn.BatchNorm2d(oup), + ) + if self.infer_mode: + self.primary_conv = nn.Sequential( + nn.Conv2d( + inp, + init_channels, + kernel_size, + stride, + kernel_size // 2, + bias=False, + ), + nn.BatchNorm2d(init_channels), + nn.ReLU(inplace=True) if relu else nn.Sequential(), + ) + self.cheap_operation = nn.Sequential( + nn.Conv2d( + init_channels, + new_channels, + dw_size, + 1, + dw_size // 2, + groups=init_channels, + bias=False, + ), + nn.BatchNorm2d(new_channels), + nn.ReLU(inplace=True) if relu else nn.Sequential(), + ) + else: + self.primary_rpr_skip = ( + nn.BatchNorm2d(inp) + if inp == init_channels and stride == 1 + else None + ) + primary_rpr_conv = list() + for _ in range(self.num_conv_branches): + primary_rpr_conv.append( + self._conv_bn( + inp, + init_channels, + kernel_size, + stride, + kernel_size // 2, + bias=False, + ) + ) + self.primary_rpr_conv = nn.ModuleList(primary_rpr_conv) + # Re-parameterizable scale branch + self.primary_rpr_scale = None + if kernel_size > 1: + self.primary_rpr_scale = self._conv_bn( + inp, init_channels, 1, 1, 0, bias=False + ) + self.primary_activation = ( + nn.ReLU(inplace=True) if relu else None + ) + + self.cheap_rpr_skip = ( + nn.BatchNorm2d(init_channels) + if init_channels == new_channels + else None + ) + cheap_rpr_conv = list() + for _ in range(self.num_conv_branches): + cheap_rpr_conv.append( + self._conv_bn( + init_channels, + new_channels, + dw_size, + 1, + dw_size // 2, + groups=init_channels, + bias=False, + ) + ) + self.cheap_rpr_conv = nn.ModuleList(cheap_rpr_conv) + # Re-parameterizable scale branch + self.cheap_rpr_scale = None + if dw_size > 1: + self.cheap_rpr_scale = self._conv_bn( + init_channels, + new_channels, + 1, + 1, + 0, + groups=init_channels, + bias=False, + ) + self.cheap_activation = nn.ReLU(inplace=True) if relu else None + self.in_channels = init_channels + self.groups = init_channels + self.kernel_size = dw_size + + def forward(self, x): + if self.mode in ["ori"]: + if self.infer_mode: + x1 = self.primary_conv(x) + x2 = self.cheap_operation(x1) + else: + identity_out = 0 + if self.primary_rpr_skip is not None: + identity_out = self.primary_rpr_skip(x) + scale_out = 0 + if self.primary_rpr_scale is not None and self.dconv_scale: + scale_out = self.primary_rpr_scale(x) + x1 = scale_out + identity_out + for ix in range(self.num_conv_branches): + x1 += self.primary_rpr_conv[ix](x) + if self.primary_activation is not None: + x1 = self.primary_activation(x1) + + cheap_identity_out = 0 + if self.cheap_rpr_skip is not None: + cheap_identity_out = self.cheap_rpr_skip(x1) + cheap_scale_out = 0 + if self.cheap_rpr_scale is not None and self.dconv_scale: + cheap_scale_out = self.cheap_rpr_scale(x1) + x2 = cheap_scale_out + cheap_identity_out + for ix in range(self.num_conv_branches): + x2 += self.cheap_rpr_conv[ix](x1) + if self.cheap_activation is not None: + x2 = self.cheap_activation(x2) + + out = torch.cat([x1, x2], dim=1) + return out + + elif self.mode in ["ori_shortcut_mul_conv15"]: + res = self.short_conv(F.avg_pool2d(x, kernel_size=2, stride=2)) + + if self.infer_mode: + x1 = self.primary_conv(x) + x2 = self.cheap_operation(x1) + else: + identity_out = 0 + if self.primary_rpr_skip is not None: + identity_out = self.primary_rpr_skip(x) + scale_out = 0 + if self.primary_rpr_scale is not None and self.dconv_scale: + scale_out = self.primary_rpr_scale(x) + x1 = scale_out + identity_out + for ix in range(self.num_conv_branches): + x1 += self.primary_rpr_conv[ix](x) + if self.primary_activation is not None: + x1 = self.primary_activation(x1) + + cheap_identity_out = 0 + if self.cheap_rpr_skip is not None: + cheap_identity_out = self.cheap_rpr_skip(x1) + cheap_scale_out = 0 + if self.cheap_rpr_scale is not None and self.dconv_scale: + cheap_scale_out = self.cheap_rpr_scale(x1) + x2 = cheap_scale_out + cheap_identity_out + for ix in range(self.num_conv_branches): + x2 += self.cheap_rpr_conv[ix](x1) + if self.cheap_activation is not None: + x2 = self.cheap_activation(x2) + + out = torch.cat([x1, x2], dim=1) + + if self.gate_loc == "before": + return out[:, : self.oup, :, :] * F.interpolate( + self.gate_fn(res / self.scale), + size=out.shape[-2:], + mode=self.inter_mode, + ) # 'nearest' + else: + return out[:, : self.oup, :, :] * self.gate_fn( + F.interpolate( + res, size=out.shape[-2:], mode=self.inter_mode + ) + ) + + def reparameterize(self): + if self.infer_mode: + return + primary_kernel, primary_bias = self._get_kernel_bias_primary() + self.primary_conv = nn.Conv2d( + in_channels=self.primary_rpr_conv[0].conv.in_channels, + out_channels=self.primary_rpr_conv[0].conv.out_channels, + kernel_size=self.primary_rpr_conv[0].conv.kernel_size, + stride=self.primary_rpr_conv[0].conv.stride, + padding=self.primary_rpr_conv[0].conv.padding, + dilation=self.primary_rpr_conv[0].conv.dilation, + groups=self.primary_rpr_conv[0].conv.groups, + bias=True, + ) + self.primary_conv.weight.data = primary_kernel + self.primary_conv.bias.data = primary_bias + self.primary_conv = nn.Sequential( + self.primary_conv, + ( + self.primary_activation + if self.primary_activation is not None + else nn.Sequential() + ), + ) + + cheap_kernel, cheap_bias = self._get_kernel_bias_cheap() + self.cheap_operation = nn.Conv2d( + in_channels=self.cheap_rpr_conv[0].conv.in_channels, + out_channels=self.cheap_rpr_conv[0].conv.out_channels, + kernel_size=self.cheap_rpr_conv[0].conv.kernel_size, + stride=self.cheap_rpr_conv[0].conv.stride, + padding=self.cheap_rpr_conv[0].conv.padding, + dilation=self.cheap_rpr_conv[0].conv.dilation, + groups=self.cheap_rpr_conv[0].conv.groups, + bias=True, + ) + self.cheap_operation.weight.data = cheap_kernel + self.cheap_operation.bias.data = cheap_bias + + self.cheap_operation = nn.Sequential( + self.cheap_operation, + ( + self.cheap_activation + if self.cheap_activation is not None + else nn.Sequential() + ), + ) + + # Delete un-used branches + for para in self.parameters(): + para.detach_() + if hasattr(self, "primary_rpr_conv"): + self.__delattr__("primary_rpr_conv") + if hasattr(self, "primary_rpr_scale"): + self.__delattr__("primary_rpr_scale") + if hasattr(self, "primary_rpr_skip"): + self.__delattr__("primary_rpr_skip") + + if hasattr(self, "cheap_rpr_conv"): + self.__delattr__("cheap_rpr_conv") + if hasattr(self, "cheap_rpr_scale"): + self.__delattr__("cheap_rpr_scale") + if hasattr(self, "cheap_rpr_skip"): + self.__delattr__("cheap_rpr_skip") + + self.infer_mode = True + + def _get_kernel_bias_primary(self) -> Tuple[torch.Tensor, torch.Tensor]: + """Method to obtain re-parameterized kernel and bias. + Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L83 + + :return: Tuple of (kernel, bias) after fusing branches. + """ + # get weights and bias of scale branch + kernel_scale = 0 + bias_scale = 0 + if self.primary_rpr_scale is not None: + kernel_scale, bias_scale = self._fuse_bn_tensor( + self.primary_rpr_scale + ) + # Pad scale branch kernel to match conv branch kernel size. + pad = self.kernel_size // 2 + kernel_scale = torch.nn.functional.pad( + kernel_scale, [pad, pad, pad, pad] + ) + + # get weights and bias of skip branch + kernel_identity = 0 + bias_identity = 0 + if self.primary_rpr_skip is not None: + kernel_identity, bias_identity = self._fuse_bn_tensor( + self.primary_rpr_skip + ) + + # get weights and bias of conv branches + kernel_conv = 0 + bias_conv = 0 + for ix in range(self.num_conv_branches): + _kernel, _bias = self._fuse_bn_tensor(self.primary_rpr_conv[ix]) + kernel_conv += _kernel + bias_conv += _bias + + kernel_final = kernel_conv + kernel_scale + kernel_identity + bias_final = bias_conv + bias_scale + bias_identity + return kernel_final, bias_final + + def _get_kernel_bias_cheap(self) -> Tuple[torch.Tensor, torch.Tensor]: + """Method to obtain re-parameterized kernel and bias. + Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L83 + + :return: Tuple of (kernel, bias) after fusing branches. + """ + # get weights and bias of scale branch + kernel_scale = 0 + bias_scale = 0 + if self.cheap_rpr_scale is not None: + kernel_scale, bias_scale = self._fuse_bn_tensor( + self.cheap_rpr_scale + ) + # Pad scale branch kernel to match conv branch kernel size. + pad = self.kernel_size // 2 + kernel_scale = torch.nn.functional.pad( + kernel_scale, [pad, pad, pad, pad] + ) + + # get weights and bias of skip branch + kernel_identity = 0 + bias_identity = 0 + if self.cheap_rpr_skip is not None: + kernel_identity, bias_identity = self._fuse_bn_tensor( + self.cheap_rpr_skip + ) + + # get weights and bias of conv branches + kernel_conv = 0 + bias_conv = 0 + for ix in range(self.num_conv_branches): + _kernel, _bias = self._fuse_bn_tensor(self.cheap_rpr_conv[ix]) + kernel_conv += _kernel + bias_conv += _bias + + kernel_final = kernel_conv + kernel_scale + kernel_identity + bias_final = bias_conv + bias_scale + bias_identity + return kernel_final, bias_final + + def _fuse_bn_tensor(self, branch) -> Tuple[torch.Tensor, torch.Tensor]: + """Method to fuse batchnorm layer with preceeding conv layer. + Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L95 + + :param branch: + :return: Tuple of (kernel, bias) after fusing batchnorm. + """ + if isinstance(branch, nn.Sequential): + kernel = branch.conv.weight + running_mean = branch.bn.running_mean + running_var = branch.bn.running_var + gamma = branch.bn.weight + beta = branch.bn.bias + eps = branch.bn.eps + else: + assert isinstance(branch, nn.BatchNorm2d) + if not hasattr(self, "id_tensor"): + input_dim = self.in_channels // self.groups + kernel_value = torch.zeros( + ( + self.in_channels, + input_dim, + self.kernel_size, + self.kernel_size, + ), + dtype=branch.weight.dtype, + device=branch.weight.device, + ) + for i in range(self.in_channels): + kernel_value[ + i, + i % input_dim, + self.kernel_size // 2, + self.kernel_size // 2, + ] = 1 + self.id_tensor = kernel_value + kernel = self.id_tensor + running_mean = branch.running_mean + running_var = branch.running_var + gamma = branch.weight + beta = branch.bias + eps = branch.eps + std = (running_var + eps).sqrt() + t = (gamma / std).reshape(-1, 1, 1, 1) + return kernel * t, beta - running_mean * gamma / std + + def _conv_bn( + self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + groups=1, + bias=False, + ): + """Helper method to construct conv-batchnorm layers. + + :param kernel_size: Size of the convolution kernel. + :param padding: Zero-padding size. + :return: Conv-BN module. + """ + mod_list = nn.Sequential() + mod_list.add_module( + "conv", + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + groups=groups, + bias=bias, + ), + ) + mod_list.add_module("bn", nn.BatchNorm2d(out_channels)) + return mod_list + + +class GhostBottleneck(nn.Module): + """Ghost bottleneck w/ optional SE""" + + def __init__( + self, + in_chs, + mid_chs, + out_chs, + dw_kernel_size=3, + stride=1, + act_layer=nn.ReLU, + se_ratio=0.0, + layer_id=None, + args=None, + ): + super(GhostBottleneck, self).__init__() + has_se = se_ratio is not None and se_ratio > 0.0 + self.stride = stride + + self.num_conv_branches = 3 + self.infer_mode = False + self.dconv_scale = True + + # Point-wise expansion + if layer_id <= 1: + self.ghost1 = GhostModule( + in_chs, mid_chs, relu=True, mode="ori", args=args + ) + else: + self.ghost1 = GhostModule( + in_chs, + mid_chs, + relu=True, + mode="ori_shortcut_mul_conv15", + args=args, + ) ####这里是扩张 mid_chs远大于in_chs + + # Depth-wise convolution + if self.stride > 1: + if self.infer_mode: + self.conv_dw = nn.Conv2d( + mid_chs, + mid_chs, + dw_kernel_size, + stride=stride, + padding=(dw_kernel_size - 1) // 2, + groups=mid_chs, + bias=False, + ) + self.bn_dw = nn.BatchNorm2d(mid_chs) + else: + self.dw_rpr_skip = ( + nn.BatchNorm2d(mid_chs) if stride == 1 else None + ) + dw_rpr_conv = list() + for _ in range(self.num_conv_branches): + dw_rpr_conv.append( + self._conv_bn( + mid_chs, + mid_chs, + dw_kernel_size, + stride, + (dw_kernel_size - 1) // 2, + groups=mid_chs, + bias=False, + ) + ) + self.dw_rpr_conv = nn.ModuleList(dw_rpr_conv) + # Re-parameterizable scale branch + self.dw_rpr_scale = None + if dw_kernel_size > 1: + self.dw_rpr_scale = self._conv_bn( + mid_chs, mid_chs, 1, 2, 0, groups=mid_chs, bias=False + ) + self.kernel_size = dw_kernel_size + self.in_channels = mid_chs + + # Squeeze-and-excitation + if has_se: + self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio) + else: + self.se = None + + # Point-wise linear projection + if layer_id <= 1: + self.ghost2 = GhostModule( + mid_chs, out_chs, relu=False, mode="ori", args=args + ) + else: + self.ghost2 = GhostModule( + mid_chs, out_chs, relu=False, mode="ori", args=args + ) + + # shortcut + if in_chs == out_chs and self.stride == 1: + self.shortcut = nn.Sequential() + else: + self.shortcut = nn.Sequential( + nn.Conv2d( + in_chs, + in_chs, + dw_kernel_size, + stride=stride, + padding=(dw_kernel_size - 1) // 2, + groups=in_chs, + bias=False, + ), + nn.BatchNorm2d(in_chs), + nn.Conv2d(in_chs, out_chs, 1, stride=1, padding=0, bias=False), + nn.BatchNorm2d(out_chs), + ) + + def forward(self, x): + residual = x + + # 1st ghost bottleneck + x = self.ghost1(x) + + # Depth-wise convolution + if self.stride > 1: + if self.infer_mode: + x = self.conv_dw(x) + x = self.bn_dw(x) + else: + dw_identity_out = 0 + if self.dw_rpr_skip is not None: + dw_identity_out = self.dw_rpr_skip(x) + dw_scale_out = 0 + if self.dw_rpr_scale is not None and self.dconv_scale: + dw_scale_out = self.dw_rpr_scale(x) + x1 = dw_scale_out + dw_identity_out + for ix in range(self.num_conv_branches): + x1 += self.dw_rpr_conv[ix](x) + x = x1 + + # Squeeze-and-excitation + if self.se is not None: + x = self.se(x) + + # 2nd ghost bottleneck + x = self.ghost2(x) + + x += self.shortcut(residual) + return x + + def _conv_bn( + self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + groups=1, + bias=False, + ): + """Helper method to construct conv-batchnorm layers. + + :param kernel_size: Size of the convolution kernel. + :param padding: Zero-padding size. + :return: Conv-BN module. + """ + mod_list = nn.Sequential() + mod_list.add_module( + "conv", + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + groups=groups, + bias=bias, + ), + ) + mod_list.add_module("bn", nn.BatchNorm2d(out_channels)) + return mod_list + + def reparameterize(self): + if self.infer_mode or self.stride == 1: + return + dw_kernel, dw_bias = self._get_kernel_bias_dw() + self.conv_dw = nn.Conv2d( + in_channels=self.dw_rpr_conv[0].conv.in_channels, + out_channels=self.dw_rpr_conv[0].conv.out_channels, + kernel_size=self.dw_rpr_conv[0].conv.kernel_size, + stride=self.dw_rpr_conv[0].conv.stride, + padding=self.dw_rpr_conv[0].conv.padding, + dilation=self.dw_rpr_conv[0].conv.dilation, + groups=self.dw_rpr_conv[0].conv.groups, + bias=True, + ) + self.conv_dw.weight.data = dw_kernel + self.conv_dw.bias.data = dw_bias + self.bn_dw = nn.Identity() + + # Delete un-used branches + for para in self.parameters(): + para.detach_() + if hasattr(self, "dw_rpr_conv"): + self.__delattr__("dw_rpr_conv") + if hasattr(self, "dw_rpr_scale"): + self.__delattr__("dw_rpr_scale") + if hasattr(self, "dw_rpr_skip"): + self.__delattr__("dw_rpr_skip") + + self.infer_mode = True + + def _get_kernel_bias_dw(self) -> Tuple[torch.Tensor, torch.Tensor]: + """Method to obtain re-parameterized kernel and bias. + Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L83 + + :return: Tuple of (kernel, bias) after fusing branches. + """ + # get weights and bias of scale branch + kernel_scale = 0 + bias_scale = 0 + if self.dw_rpr_scale is not None: + kernel_scale, bias_scale = self._fuse_bn_tensor(self.dw_rpr_scale) + # Pad scale branch kernel to match conv branch kernel size. + pad = self.kernel_size // 2 + kernel_scale = torch.nn.functional.pad( + kernel_scale, [pad, pad, pad, pad] + ) + + # get weights and bias of skip branch + kernel_identity = 0 + bias_identity = 0 + if self.dw_rpr_skip is not None: + kernel_identity, bias_identity = self._fuse_bn_tensor( + self.dw_rpr_skip + ) + + # get weights and bias of conv branches + kernel_conv = 0 + bias_conv = 0 + for ix in range(self.num_conv_branches): + _kernel, _bias = self._fuse_bn_tensor(self.dw_rpr_conv[ix]) + kernel_conv += _kernel + bias_conv += _bias + + kernel_final = kernel_conv + kernel_scale + kernel_identity + bias_final = bias_conv + bias_scale + bias_identity + return kernel_final, bias_final + + def _fuse_bn_tensor(self, branch) -> Tuple[torch.Tensor, torch.Tensor]: + """Method to fuse batchnorm layer with preceeding conv layer. + Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L95 + + :param branch: + :return: Tuple of (kernel, bias) after fusing batchnorm. + """ + if isinstance(branch, nn.Sequential): + kernel = branch.conv.weight + running_mean = branch.bn.running_mean + running_var = branch.bn.running_var + gamma = branch.bn.weight + beta = branch.bn.bias + eps = branch.bn.eps + else: + assert isinstance(branch, nn.BatchNorm2d) + if not hasattr(self, "id_tensor"): + input_dim = self.in_channels // self.groups + kernel_value = torch.zeros( + ( + self.in_channels, + input_dim, + self.kernel_size, + self.kernel_size, + ), + dtype=branch.weight.dtype, + device=branch.weight.device, + ) + for i in range(self.in_channels): + kernel_value[ + i, + i % input_dim, + self.kernel_size // 2, + self.kernel_size // 2, + ] = 1 + self.id_tensor = kernel_value + kernel = self.id_tensor + running_mean = branch.running_mean + running_var = branch.running_var + gamma = branch.weight + beta = branch.bias + eps = branch.eps + std = (running_var + eps).sqrt() + t = (gamma / std).reshape(-1, 1, 1, 1) + return kernel * t, beta - running_mean * gamma / std + + +class GhostNet(nn.Module): + def __init__( + self, + cfgs, + num_classes=1000, + width=1.0, + dropout=0.2, + block=GhostBottleneck, + args=None, + ): + super(GhostNet, self).__init__() + # setting of inverted residual blocks + self.cfgs = cfgs + self.dropout = dropout + + # building first layer + output_channel = _make_divisible(16 * width, 4) + self.conv_stem = nn.Conv2d(3, output_channel, 3, 2, 1, bias=False) + self.bn1 = nn.BatchNorm2d(output_channel) + self.act1 = nn.ReLU(inplace=True) + input_channel = output_channel + + # building inverted residual blocks + stages = [] + # block = block + layer_id = 0 + for cfg in self.cfgs: + layers = [] + for k, exp_size, c, se_ratio, s in cfg: + output_channel = _make_divisible(c * width, 4) + hidden_channel = _make_divisible(exp_size * width, 4) + if block == GhostBottleneck: + layers.append( + block( + input_channel, + hidden_channel, + output_channel, + k, + s, + se_ratio=se_ratio, + layer_id=layer_id, + args=args, + ) + ) + input_channel = output_channel + layer_id += 1 + stages.append(nn.Sequential(*layers)) + + output_channel = _make_divisible(exp_size * width, 4) + stages.append( + nn.Sequential(ConvBnAct(input_channel, output_channel, 1)) + ) + input_channel = output_channel + + self.blocks = nn.Sequential(*stages) + + # building last several layers + output_channel = 1280 + self.global_pool = nn.AdaptiveAvgPool2d((1, 1)) + self.conv_head = nn.Conv2d( + input_channel, output_channel, 1, 1, 0, bias=True + ) + self.act2 = nn.ReLU(inplace=True) + self.classifier = nn.Linear(output_channel, num_classes) + + def forward(self, x): + x = self.conv_stem(x) + x = self.bn1(x) + x = self.act1(x) + x = self.blocks(x) + x = self.global_pool(x) + x = self.conv_head(x) + x = self.act2(x) + x = x.view(x.size(0), -1) + # if self.dropout > 0.: + # x = F.dropout(x, p=self.dropout, training=self.training) + x = self.classifier(x) + x = x.squeeze() + return x + + def reparameterize(self): + for _, module in self.named_modules(): + if isinstance(module, GhostModule): + module.reparameterize() + if isinstance(module, GhostBottleneck): + module.reparameterize() + + +def ghostnetv3(**kwargs): + """ + Constructs a GhostNet model + """ + cfgs = [ + # k, t, c, SE, s + # stage1 + [[3, 16, 16, 0, 1]], + # stage2 + [[3, 48, 24, 0, 2]], + [[3, 72, 24, 0, 1]], + # stage3 + [[5, 72, 40, 0.25, 2]], + [[5, 120, 40, 0.25, 1]], + # stage4 + [[3, 240, 80, 0, 2]], + [ + [3, 200, 80, 0, 1], + [3, 184, 80, 0, 1], + [3, 184, 80, 0, 1], + [3, 480, 112, 0.25, 1], + [3, 672, 112, 0.25, 1], + ], + # stage5 + [[5, 672, 160, 0.25, 2]], + [ + [5, 960, 160, 0, 1], + [5, 960, 160, 0.25, 1], + [5, 960, 160, 0, 1], + [5, 960, 160, 0.25, 1], + ], + ] + return GhostNet(cfgs, num_classes=1000, width=kwargs["width"], dropout=0.2)