diff --git a/README.md b/README.md index bd0b45b..bbfcbd1 100644 --- a/README.md +++ b/README.md @@ -167,11 +167,13 @@ Reference: [Grad-CAM class activation visualization (keras.io)](https://keras.io |EfficientNetV2|[ICML 2021](https://arxiv.org/abs/2104.00298)|`timm`|`kimm.models.EfficientNetV2*`| |GhostNet|[CVPR 2020](https://arxiv.org/abs/1911.11907)|`timm`|`kimm.models.GhostNet*`| |GhostNetV2|[NeurIPS 2022](https://arxiv.org/abs/2211.12905)|`timm`|`kimm.models.GhostNetV2*`| +|InceptionNeXt|[arXiv 2023](https://arxiv.org/abs/2303.16900)|`timm`|`kimm.models.InceptionNeXt*`| |InceptionV3|[CVPR 2016](https://arxiv.org/abs/1512.00567)|`timm`|`kimm.models.InceptionV3`| |LCNet|[arXiv 2021](https://arxiv.org/abs/2109.15099)|`timm`|`kimm.models.LCNet*`| |MobileNetV2|[CVPR 2018](https://arxiv.org/abs/1801.04381)|`timm`|`kimm.models.MobileNetV2*`| |MobileNetV3|[ICCV 2019](https://arxiv.org/abs/1905.02244)|`timm`|`kimm.models.MobileNetV3*`| |MobileViT|[ICLR 2022](https://arxiv.org/abs/2110.02178)|`timm`|`kimm.models.MobileViT*`| +|MobileViTV2|[arXiv 2022](https://arxiv.org/abs/2206.02680)|`timm`|`kimm.models.MobileViTV2*`| |RegNet|[CVPR 2020](https://arxiv.org/abs/2003.13678)|`timm`|`kimm.models.RegNet*`| |ResNet|[CVPR 2015](https://arxiv.org/abs/1512.03385)|`timm`|`kimm.models.ResNet*`| |TinyNet|[NeurIPS 2020](https://arxiv.org/abs/2010.14819)|`timm`|`kimm.models.TinyNet*`| diff --git a/kimm/__init__.py b/kimm/__init__.py index 626592d..0ef3a09 100644 --- a/kimm/__init__.py +++ b/kimm/__init__.py @@ -2,4 +2,4 @@ from kimm import models # force to add models to the registry from kimm.utils.model_registry import list_models -__version__ = "0.1.4" +__version__ = "0.1.5" diff --git a/kimm/export/export_onnx_test.py b/kimm/export/export_onnx_test.py index a4071fc..35e0d28 100644 --- a/kimm/export/export_onnx_test.py +++ b/kimm/export/export_onnx_test.py @@ -10,7 +10,7 @@ class ExportOnnxTest(testing.TestCase, parameterized.TestCase): def get_model(self): input_shape = [3, 224, 224] # channels_first - model = models.MobileNet050V3Small(include_preprocessing=False) + model = models.MobileNetV3W050Small(include_preprocessing=False) return input_shape, model @classmethod diff --git a/kimm/export/export_tflite_test.py b/kimm/export/export_tflite_test.py index fdfebc1..15d2c24 100644 --- a/kimm/export/export_tflite_test.py +++ b/kimm/export/export_tflite_test.py @@ -12,7 +12,7 @@ class ExportTFLiteTest(testing.TestCase, parameterized.TestCase): def get_model_and_representative_dataset(self): input_shape = [224, 224, 3] - model = models.MobileNet050V3Small(include_preprocessing=False) + model = models.MobileNetV3W050Small(include_preprocessing=False) def representative_dataset(): for _ in range(10): diff --git a/kimm/models/__init__.py b/kimm/models/__init__.py index f1540de..b2ec736 100644 --- a/kimm/models/__init__.py +++ b/kimm/models/__init__.py @@ -4,6 +4,7 @@ from kimm.models.densenet import * # noqa:F403 from kimm.models.efficientnet import * # noqa:F403 from kimm.models.ghostnet import * # noqa:F403 +from kimm.models.inception_next import * # noqa:F403 from kimm.models.inception_v3 import * # noqa:F403 from kimm.models.mobilenet_v2 import * # noqa:F403 from kimm.models.mobilenet_v3 import * # noqa:F403 diff --git a/kimm/models/convnext.py b/kimm/models/convnext.py index 82e3a47..271b01c 100644 --- a/kimm/models/convnext.py +++ b/kimm/models/convnext.py @@ -24,8 +24,8 @@ def apply_convnext_block( ): channels_axis = -1 if backend.image_data_format() == "channels_last" else -3 input_channels = 
inputs.shape[channels_axis] - hidden_channels = int(mlp_ratio * output_channels) + x = inputs shortcut = inputs diff --git a/kimm/models/inception_next.py b/kimm/models/inception_next.py new file mode 100644 index 0000000..3d4eedc --- /dev/null +++ b/kimm/models/inception_next.py @@ -0,0 +1,393 @@ +import itertools +import typing + +import keras +from keras import backend +from keras import initializers +from keras import layers +from keras import ops + +from kimm import layers as kimm_layers +from kimm.blocks import apply_mlp_block +from kimm.models import BaseModel +from kimm.utils import add_model_to_registry + + +def apply_inception_depthwise_conv2d( + inputs, + square_kernel_size: int = 3, + band_kernel_size: int = 11, + branch_ratio: float = 0.125, + name="inception_depthwise_conv2d", +): + channels_axis = -1 if backend.image_data_format() == "channels_last" else -3 + input_channels = inputs.shape[channels_axis] + branch_channels = int(input_channels * branch_ratio) + split_sizes = ( + input_channels - 3 * branch_channels, + branch_channels, + branch_channels, + branch_channels, + ) + split_indices = list(itertools.accumulate(split_sizes[:-1])) + square_padding = (square_kernel_size - 1) // 2 + band_padding = (band_kernel_size - 1) // 2 + + x = inputs + + x_id, x_hw, x_w, x_h = ops.split(x, split_indices, axis=channels_axis) + x_hw = layers.ZeroPadding2D(square_padding)(x_hw) + x_hw = layers.DepthwiseConv2D( + square_kernel_size, + use_bias=True, + name=f"{name}_dwconv_hw_dwconv2d", + )(x_hw) + + x_w = layers.ZeroPadding2D((0, band_padding))(x_w) + x_w = layers.DepthwiseConv2D( + (1, band_kernel_size), + use_bias=True, + name=f"{name}_dwconv_w_dwconv2d", + )(x_w) + + x_h = layers.ZeroPadding2D((band_padding, 0))(x_h) + x_h = layers.DepthwiseConv2D( + (band_kernel_size, 1), + use_bias=True, + name=f"{name}_dwconv_h_dwconv2d", + )(x_h) + + x = layers.Concatenate(axis=channels_axis)([x_id, x_hw, x_w, x_h]) + return x + + +def apply_metanext_block( + inputs, + output_channels: int, + mlp_ratio: float = 4.0, + activation="gelu", + name="metanext_block", +): + channels_axis = -1 if backend.image_data_format() == "channels_last" else -3 + hidden_channels = int(mlp_ratio * output_channels) + + x = inputs + + x = apply_inception_depthwise_conv2d(x, name=f"{name}_token_mixer") + x = layers.BatchNormalization( + axis=channels_axis, epsilon=1e-5, name=f"{name}_norm" + )(x) + x = apply_mlp_block( + x, + hidden_channels, + output_channels, + activation, + use_bias=True, + use_conv_mlp=True, + name=f"{name}_mlp", + ) + x = kimm_layers.LayerScale( + axis=channels_axis, + initializer=initializers.Constant(1e-6), + name=f"{name}_layerscale", + )(x) + + x = layers.Add()([x, inputs]) + return x + + +def apply_metanext_stage( + inputs, + depth: int, + output_channels: int, + strides: int, + mlp_ratio: float = 4, + activation="gelu", + name="convnext_stage", +): + channels_axis = -1 if backend.image_data_format() == "channels_last" else -3 + + x = inputs + + # Downsample + if strides > 1: + x = layers.BatchNormalization( + axis=channels_axis, + momentum=0.9, + epsilon=1e-5, + name=f"{name}_downsample_0", + )(x) + x = layers.Conv2D( + output_channels, + 2, + strides, + use_bias=True, + name=f"{name}_downsample_1_conv2d", + )(x) + + # Blocks + for i in range(depth): + x = apply_metanext_block( + x, + output_channels, + mlp_ratio=mlp_ratio, + activation=activation, + name=f"{name}_blocks_{i}", + ) + return x + + +@keras.saving.register_keras_serializable(package="kimm") +class InceptionNeXt(BaseModel): + 
available_feature_keys = [ + "STEM_S4", + *[f"BLOCK{i}_S{j}" for i, j in zip(range(4), [4, 8, 16, 32])], + ] + + def __init__( + self, + depths: typing.Sequence[int] = [3, 3, 9, 3], + hidden_channels: typing.Sequence[int] = [96, 192, 384, 768], + mlp_ratios: typing.Sequence[float] = [4, 4, 4, 3], + activation: str = "gelu", + **kwargs, + ): + kwargs["weights_url"] = self.get_weights_url(kwargs["weights"]) + + input_tensor = kwargs.pop("input_tensor", None) + self.set_properties(kwargs, 224) + channels_axis = ( + -1 if backend.image_data_format() == "channels_last" else -3 + ) + + inputs = self.determine_input_tensor( + input_tensor, + self._input_shape, + self._default_size, + require_flatten=self._include_top, + ) + x = inputs + + x = self.build_preprocessing(x, "imagenet") + + # Prepare feature extraction + features = {} + + # Stem + x = layers.Conv2D( + hidden_channels[0], 4, 4, use_bias=True, name="stem_0_conv2d" + )(x) + x = layers.BatchNormalization( + axis=channels_axis, momentum=0.9, epsilon=1e-5, name="stem_1" + )(x) + features["STEM_S4"] = x + + # Blocks (4 stages) + current_stride = 4 + for i in range(4): + strides = 2 if i > 0 else 1 + x = apply_metanext_stage( + x, + depths[i], + hidden_channels[i], + strides, + mlp_ratios[i], + activation=activation, + name=f"stages_{i}", + ) + current_stride *= strides + # Add feature + features[f"BLOCK{i}_S{current_stride}"] = x + + # Head + x = self.build_head(x) + super().__init__(inputs=inputs, outputs=x, features=features, **kwargs) + + # All references to `self` below this line + self.depths = depths + self.hidden_channels = hidden_channels + self.mlp_ratios = mlp_ratios + self.activation = activation + + def build_top(self, inputs, classes, classifier_activation, dropout_rate): + channels_axis = ( + -1 if backend.image_data_format() == "channels_last" else -3 + ) + input_channels = inputs.shape[channels_axis] + hidden_channels = int(input_channels * 3.0) + + x = inputs + x = layers.GlobalAveragePooling2D(name="avg_pool")(inputs) + x = layers.Dense(hidden_channels, use_bias=True, name="head_fc1")(x) + x = layers.Activation("gelu")(x) + x = layers.LayerNormalization(axis=-1, epsilon=1e-6, name="head_norm")( + x + ) + x = layers.Dropout(rate=dropout_rate, name="head_dropout")(x) + x = layers.Dense( + classes, activation=classifier_activation, name="classifier" + )(x) + return x + + def get_config(self): + config = super().get_config() + config.update( + { + "depths": self.depths, + "hidden_channels": self.hidden_channels, + "mlp_ratios": self.mlp_ratios, + "activation": self.activation, + } + ) + return config + + def fix_config(self, config: typing.Dict): + unused_kwargs = [ + "depths", + "hidden_channels", + "mlp_ratios", + "activation", + ] + for k in unused_kwargs: + config.pop(k, None) + return config + + +""" +Model Definition +""" + + +class InceptionNeXtTiny(InceptionNeXt): + available_weights = [ + ( + "imagenet", + InceptionNeXt.default_origin, + "inceptionnexttiny_inception_next_tiny.sail_in1k.keras", + ) + ] + + def __init__( + self, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = "imagenet", + name: str = "InceptionNeXtTiny", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + super().__init__( + (3, 3, 9, 3), + (96, 192, 384, 768), + (4, 
4, 4, 3), + "gelu", + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, + name=name, + **kwargs, + ) + + +class InceptionNeXtSmall(InceptionNeXt): + available_weights = [ + ( + "imagenet", + InceptionNeXt.default_origin, + "inceptionnextsmall_inception_next_small.sail_in1k.keras", + ) + ] + + def __init__( + self, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = "imagenet", + name: str = "InceptionNeXtSmall", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + super().__init__( + (3, 3, 27, 3), + (96, 192, 384, 768), + (4, 4, 4, 3), + "gelu", + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, + name=name, + **kwargs, + ) + + +class InceptionNeXtBase(InceptionNeXt): + available_weights = [ + ( + "imagenet", + InceptionNeXt.default_origin, + "inceptionnextbase_inception_next_base.sail_in1k_384.keras", + ) + ] + + def __init__( + self, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = "imagenet", + name: str = "InceptionNeXtBase", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + super().__init__( + (3, 3, 27, 3), + (128, 256, 512, 1024), + (4, 4, 4, 3), + "gelu", + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, + name=name, + default_size=384, + **kwargs, + ) + + +add_model_to_registry(InceptionNeXtTiny, "imagenet") +add_model_to_registry(InceptionNeXtSmall, "imagenet") +add_model_to_registry(InceptionNeXtBase, "imagenet") diff --git a/kimm/models/mobilenet_v2.py b/kimm/models/mobilenet_v2.py index f5150f1..16489b9 100644 --- a/kimm/models/mobilenet_v2.py +++ b/kimm/models/mobilenet_v2.py @@ -150,7 +150,7 @@ def fix_config(self, config): """ -class MobileNet050V2(MobileNetV2): +class MobileNetV2W050(MobileNetV2): available_weights = [ ( "imagenet", @@ -171,7 +171,7 @@ def __init__( classifier_activation: str = "softmax", weights: typing.Optional[str] = "imagenet", config: typing.Union[str, typing.List] = "default", - name: str = "MobileNet050V2", + name: str = "MobileNetV2W050", **kwargs, ): kwargs = self.fix_config(kwargs) @@ -194,7 +194,7 @@ def __init__( ) -class MobileNet100V2(MobileNetV2): +class MobileNetV2W100(MobileNetV2): available_weights = [ ( "imagenet", @@ -215,7 +215,7 @@ def __init__( classifier_activation: str = "softmax", weights: typing.Optional[str] = "imagenet", config: typing.Union[str, typing.List] = "default", - name: str = "MobileNet100V2", + name: str = "MobileNetV2W100", 
**kwargs, ): kwargs = self.fix_config(kwargs) @@ -238,7 +238,7 @@ def __init__( ) -class MobileNet110V2(MobileNetV2): +class MobileNetV2W110(MobileNetV2): available_weights = [ ( "imagenet", @@ -259,7 +259,7 @@ def __init__( classifier_activation: str = "softmax", weights: typing.Optional[str] = "imagenet", config: typing.Union[str, typing.List] = "default", - name: str = "MobileNet110V2", + name: str = "MobileNetV2W110", **kwargs, ): kwargs = self.fix_config(kwargs) @@ -282,7 +282,7 @@ def __init__( ) -class MobileNet120V2(MobileNetV2): +class MobileNetV2W120(MobileNetV2): available_weights = [ ( "imagenet", @@ -303,7 +303,7 @@ def __init__( classifier_activation: str = "softmax", weights: typing.Optional[str] = "imagenet", config: typing.Union[str, typing.List] = "default", - name: str = "MobileNet120V2", + name: str = "MobileNetV2W120", **kwargs, ): kwargs = self.fix_config(kwargs) @@ -326,7 +326,7 @@ def __init__( ) -class MobileNet140V2(MobileNetV2): +class MobileNetV2W140(MobileNetV2): available_weights = [ ( "imagenet", @@ -347,7 +347,7 @@ def __init__( classifier_activation: str = "softmax", weights: typing.Optional[str] = "imagenet", config: typing.Union[str, typing.List] = "default", - name: str = "MobileNet140V2", + name: str = "MobileNetV2W140", **kwargs, ): kwargs = self.fix_config(kwargs) @@ -370,8 +370,8 @@ def __init__( ) -add_model_to_registry(MobileNet050V2, "imagenet") -add_model_to_registry(MobileNet100V2, "imagenet") -add_model_to_registry(MobileNet110V2, "imagenet") -add_model_to_registry(MobileNet120V2, "imagenet") -add_model_to_registry(MobileNet140V2, "imagenet") +add_model_to_registry(MobileNetV2W050, "imagenet") +add_model_to_registry(MobileNetV2W100, "imagenet") +add_model_to_registry(MobileNetV2W110, "imagenet") +add_model_to_registry(MobileNetV2W120, "imagenet") +add_model_to_registry(MobileNetV2W140, "imagenet") diff --git a/kimm/models/mobilenet_v3.py b/kimm/models/mobilenet_v3.py index 5d61396..52791af 100644 --- a/kimm/models/mobilenet_v3.py +++ b/kimm/models/mobilenet_v3.py @@ -303,7 +303,7 @@ def fix_config(self, config): """ -class MobileNet050V3Small(MobileNetV3): +class MobileNetV3W050Small(MobileNetV3): available_feature_keys = [ "STEM_S2", *[f"BLOCK{i}_S{j}" for i, j in zip(range(6), [4, 8, 16, 16, 32, 32])], @@ -328,7 +328,7 @@ def __init__( classifier_activation: str = "softmax", weights: typing.Optional[str] = "imagenet", config: typing.Union[str, typing.List] = "small", - name: str = "MobileNet050V3Small", + name: str = "MobileNetV3W050Small", **kwargs, ): kwargs = self.fix_config(kwargs) @@ -351,7 +351,7 @@ def __init__( ) -class MobileNet075V3Small(MobileNetV3): +class MobileNetV3W075Small(MobileNetV3): available_feature_keys = [ "STEM_S2", *[f"BLOCK{i}_S{j}" for i, j in zip(range(6), [4, 8, 16, 16, 32, 32])], @@ -376,7 +376,7 @@ def __init__( classifier_activation: str = "softmax", weights: typing.Optional[str] = "imagenet", config: typing.Union[str, typing.List] = "small", - name: str = "MobileNet075V3Small", + name: str = "MobileNetV3W075Small", **kwargs, ): kwargs = self.fix_config(kwargs) @@ -399,7 +399,7 @@ def __init__( ) -class MobileNet100V3Small(MobileNetV3): +class MobileNetV3W100Small(MobileNetV3): available_feature_keys = [ "STEM_S2", *[f"BLOCK{i}_S{j}" for i, j in zip(range(6), [4, 8, 16, 16, 32, 32])], @@ -424,7 +424,7 @@ def __init__( classifier_activation: str = "softmax", weights: typing.Optional[str] = "imagenet", config: typing.Union[str, typing.List] = "small", - name: str = "MobileNet100V3Small", + name: str = 
"MobileNetV3W100Small", **kwargs, ): kwargs = self.fix_config(kwargs) @@ -447,7 +447,7 @@ def __init__( ) -class MobileNet100V3SmallMinimal(MobileNetV3): +class MobileNetV3W100SmallMinimal(MobileNetV3): available_feature_keys = [ "STEM_S2", *[f"BLOCK{i}_S{j}" for i, j in zip(range(6), [4, 8, 16, 16, 32, 32])], @@ -475,7 +475,7 @@ def __init__( classifier_activation: str = "softmax", weights: typing.Optional[str] = "imagenet", config: typing.Union[str, typing.List] = "small", - name: str = "MobileNet100V3SmallMinimal", + name: str = "MobileNetV3W100SmallMinimal", **kwargs, ): kwargs = self.fix_config(kwargs) @@ -502,7 +502,7 @@ def __init__( ) -class MobileNet100V3Large(MobileNetV3): +class MobileNetV3W100Large(MobileNetV3): available_feature_keys = [ "STEM_S2", *[ @@ -533,7 +533,7 @@ def __init__( classifier_activation: str = "softmax", weights: typing.Optional[str] = "imagenet", config: typing.Union[str, typing.List] = "large", - name: str = "MobileNet100V3Large", + name: str = "MobileNetV3W100Large", **kwargs, ): kwargs = self.fix_config(kwargs) @@ -566,7 +566,7 @@ def build_preprocessing(self, inputs, mode="imagenet"): return super().build_preprocessing(inputs, mode) -class MobileNet100V3LargeMinimal(MobileNetV3): +class MobileNetV3W100LargeMinimal(MobileNetV3): available_feature_keys = [ "STEM_S2", *[ @@ -597,7 +597,7 @@ def __init__( classifier_activation: str = "softmax", weights: typing.Optional[str] = "imagenet", config: typing.Union[str, typing.List] = "large", - name: str = "MobileNet100V3LargeMinimal", + name: str = "MobileNetV3W100LargeMinimal", **kwargs, ): kwargs = self.fix_config(kwargs) @@ -860,12 +860,12 @@ def __init__( ) -add_model_to_registry(MobileNet050V3Small, "imagenet") -add_model_to_registry(MobileNet075V3Small, "imagenet") -add_model_to_registry(MobileNet100V3Small, "imagenet") -add_model_to_registry(MobileNet100V3SmallMinimal, "imagenet") -add_model_to_registry(MobileNet100V3Large, "imagenet") -add_model_to_registry(MobileNet100V3LargeMinimal, "imagenet") +add_model_to_registry(MobileNetV3W050Small, "imagenet") +add_model_to_registry(MobileNetV3W075Small, "imagenet") +add_model_to_registry(MobileNetV3W100Small, "imagenet") +add_model_to_registry(MobileNetV3W100SmallMinimal, "imagenet") +add_model_to_registry(MobileNetV3W100Large, "imagenet") +add_model_to_registry(MobileNetV3W100LargeMinimal, "imagenet") add_model_to_registry(LCNet035) add_model_to_registry(LCNet050, "imagenet") add_model_to_registry(LCNet075, "imagenet") diff --git a/kimm/models/mobilevit.py b/kimm/models/mobilevit.py index 583706d..ba3acd6 100644 --- a/kimm/models/mobilevit.py +++ b/kimm/models/mobilevit.py @@ -8,13 +8,14 @@ from kimm.blocks import apply_conv2d_block from kimm.blocks import apply_inverted_residual_block +from kimm.blocks import apply_mlp_block from kimm.blocks import apply_transformer_block from kimm.models.base_model import BaseModel from kimm.utils import add_model_to_registry from kimm.utils import make_divisible -# type, repeat, channels, strides, expansion_ratio, transformer_dim, -# transformer_depth, patch_size +# type, repeat, kernel_size, channels, strides, expansion_ratio, +# transformer_dim, transformer_depth, patch_size DEFAULT_V1_S_CONFIG = [ ["ir", 1, 3, 32, 1, 4.0, None, None, None], ["ir", 3, 3, 64, 2, 4.0, None, None, None], @@ -36,6 +37,13 @@ ["mobilevit", 1, 3, 64, 2, 2.0, 80, 4, 2], ["mobilevit", 1, 3, 80, 2, 2.0, 96, 3, 2], ] +DEFAULT_V2_CONFIG = [ + ["ir", 1, 3, 64, 1, 2.0, None, None, None], + ["ir", 2, 3, 128, 2, 2.0, None, None, None], + ["mobilevitv2", 
1, 3, 256, 2, 2.0, 128, 2, 2], + ["mobilevitv2", 1, 3, 384, 2, 2.0, 192, 4, 2], + ["mobilevitv2", 1, 3, 512, 2, 2.0, 256, 3, 2], +] def unfold(inputs, patch_size): @@ -89,8 +97,6 @@ def apply_mobilevit_block( transformer_depth: int = 2, patch_size: int = 8, num_heads: int = 4, - projection_dropout_rate=0.0, - attention_dropout_rate=0.0, activation="swish", transformer_activation="swish", fusion: bool = True, @@ -134,8 +140,6 @@ def apply_mobilevit_block( mlp_ratio, True, False, - attention_dropout_rate=attention_dropout_rate, - projection_dropout_rate=projection_dropout_rate, activation=transformer_activation, name=f"{name}_transformer_{i}", ) @@ -170,6 +174,212 @@ def apply_mobilevit_block( return x +def unfold_v2(inputs, patch_size): + x = inputs + + if backend.image_data_format() == "channels_last": + h, w, c = x.shape[-3], x.shape[-2], x.shape[-1] + else: + c, h, w = x.shape[-3], x.shape[-2], x.shape[-1] + + new_h, new_w = ( + math.ceil(h / patch_size) * patch_size, + math.ceil(w / patch_size) * patch_size, + ) + num_patches_h = new_h // patch_size + num_patches_w = new_w // patch_size + num_patches = num_patches_h * num_patches_w + + if backend.image_data_format() == "channels_last": + # [B, H, W, C] -> [B, P, N, C] + x = ops.reshape( + x, [-1, num_patches_h, patch_size, num_patches_w, patch_size, c] + ) + x = ops.transpose(x, [0, 2, 4, 1, 3, 5]) + x = ops.reshape(x, [-1, patch_size * patch_size, num_patches, c]) + else: + # [B, C, H, W] -> [B, C, P, N] + x = ops.reshape( + x, [-1, c, num_patches_h, patch_size, num_patches_w, patch_size] + ) + x = ops.transpose(x, [0, 1, 3, 5, 2, 4]) + x = ops.reshape(x, [-1, c, patch_size * patch_size, num_patches]) + return x + + +def fold_v2(inputs, h, w, c, patch_size): + x = inputs + + new_h, new_w = ( + math.ceil(h / patch_size) * patch_size, + math.ceil(w / patch_size) * patch_size, + ) + num_patches_h = new_h // patch_size + num_patches_w = new_w // patch_size + if backend.image_data_format() == "channels_last": + # [B, P, N, C] -> [B, H, W, C] + x = ops.reshape( + x, [-1, patch_size, patch_size, num_patches_h, num_patches_w, c] + ) + x = ops.transpose(x, [0, 3, 1, 4, 2, 5]) + x = ops.reshape( + x, [-1, num_patches_h * patch_size, num_patches_w * patch_size, c] + ) + else: + # [B, C, P, N] -> [B, C, H, W] + x = ops.reshape( + x, [-1, c, patch_size, patch_size, num_patches_h, num_patches_w] + ) + x = ops.transpose(x, [0, 1, 4, 2, 5, 3]) + x = ops.reshape( + x, [-1, c, num_patches_h * patch_size, num_patches_w * patch_size] + ) + return x + + +def apply_linear_self_attention_block( + inputs, dim: int, use_bias=True, name="linear_self_attention_block" +): + channels_axis = -1 if backend.image_data_format() == "channels_last" else -3 + num_patch_axis = ( + -2 if backend.image_data_format() == "channels_last" else -1 + ) + + x = inputs + + # [B, P, N, C] -> [B, P, N, h + 2d] + # Project x into query, key and value + # Query: [B, P, N, 1] + # Value & Key: [B, P, N, d] + x = layers.Conv2D( + 1 + (2 * dim), 1, use_bias=use_bias, name=f"{name}_qkv_proj_conv2d" + )(x) + query, key, value = ops.split(x, [1, 1 + dim], axis=channels_axis) + + # Apply softmax along N dimension + context_scores = ops.softmax(query, axis=num_patch_axis) + + # Compute context vector + # [B, P, N, d] x [B, P, N, 1] -> [B, P, N, d] -> [B, P, 1, d] + context_vector = layers.Multiply()([key, context_scores]) + context_vector = ops.sum(context_vector, axis=num_patch_axis, keepdims=True) + + # Combine context vector with values + # [B, P, N, d] * [B, P, 1, d] -> [B, P, N, d] + out = 
layers.ReLU()(value) + out = layers.Multiply()([out, context_vector]) + out = layers.Conv2D( + dim, 1, use_bias=use_bias, name=f"{name}_out_proj_conv2d" + )(out) + return out + + +def apply_linear_transformer_block( + inputs, + dim: int, + mlp_ratio: float = 2.0, + activation="swish", + name="linear_transformer_block", +): + channels_axis = -1 if backend.image_data_format() == "channels_last" else -3 + x = inputs + + # Self-attention + x = layers.GroupNormalization( + 1, axis=channels_axis, epsilon=1e-5, name=f"{name}_norm1" + )(x) + x = apply_linear_self_attention_block( + x, dim, use_bias=True, name=f"{name}_attn" + ) + x = layers.Add()([inputs, x]) + + # Feedforward network + residual = x + x = layers.GroupNormalization( + 1, axis=channels_axis, epsilon=1e-5, name=f"{name}_norm2" + )(x) + x = apply_mlp_block( + x, + int(dim * mlp_ratio), + activation=activation, + use_bias=True, + use_conv_mlp=True, + name=f"{name}_mlp", + ) + x = layers.Add()([residual, x]) + return x + + +def apply_mobilevitv2_block( + inputs, + output_channels: int, + kernel_size: int = 3, + strides: int = 1, + expansion_ratio: float = 1.0, + mlp_ratio: float = 2.0, + transformer_dim: typing.Optional[int] = None, + transformer_depth: int = 2, + patch_size: int = 8, + activation="swish", + transformer_activation="swish", + name="mobilevitv2_block", +): + channels_axis = -1 if backend.image_data_format() == "channels_last" else -3 + input_channels = inputs.shape[channels_axis] + transformer_dim = transformer_dim or make_divisible( + input_channels * expansion_ratio + ) + + x = inputs + + # Local representation + x = apply_conv2d_block( + x, + input_channels, + kernel_size, + strides, + activation=activation, + use_depthwise=True, + name=f"{name}_conv_kxk", + ) + x = layers.Conv2D( + transformer_dim, 1, use_bias=False, name=f"{name}_conv_1x1" + )(x) + + # Unfold (feature map -> patches) + if backend.image_data_format() == "channels_last": + h, w, c = x.shape[-3], x.shape[-2], x.shape[-1] + else: + c, h, w = x.shape[-3], x.shape[-2], x.shape[-1] + x = unfold_v2(x, patch_size) + + # Global representations: + for i in range(transformer_depth): + x = apply_linear_transformer_block( + x, + transformer_dim, + mlp_ratio, + activation=transformer_activation, + name=f"{name}_transformer_{i}", + ) + x = layers.GroupNormalization( + 1, axis=channels_axis, epsilon=1e-5, name=f"{name}_norm" + )(x) + + # Fold (patch -> feature map) + x = fold_v2(x, h, w, c, patch_size) + + x = apply_conv2d_block( + x, + output_channels, + 1, + 1, + activation=None, + name=f"{name}_conv_proj", + ) + return x + + @keras.saving.register_keras_serializable(package="kimm") class MobileViT(BaseModel): available_feature_keys = [ @@ -300,12 +510,132 @@ def fix_config(self, config): return config -class MobileViTS(MobileViT): +@keras.saving.register_keras_serializable(package="kimm") +class MobileViTV2(BaseModel): + available_feature_keys = [ + "STEM_S2", + *[f"BLOCK{i}_S{j}" for i, j in zip(range(5), [2, 4, 8, 16, 32])], + ] + + def __init__( + self, + multiplier: float = 1.0, + activation: str = "swish", + config: str = "v2", + **kwargs, + ): + kwargs["weights_url"] = self.get_weights_url(kwargs["weights"]) + + _available_configs = ["v2"] + if config == "v2": + _config = DEFAULT_V2_CONFIG + else: + raise ValueError( + f"config must be one of {_available_configs} using string. 
" + f"Received: config={config}" + ) + + input_tensor = kwargs.pop("input_tensor", None) + self.set_properties(kwargs, 256) + + inputs = self.determine_input_tensor( + input_tensor, + self._input_shape, + self._default_size, + static_shape=True, + ) + x = inputs + + x = self.build_preprocessing(x, "0_1") + + # Prepare feature extraction + features = {} + + # stem + x = apply_conv2d_block( + x, + int(32 * multiplier), + 3, + 2, + activation=activation, + name="stem", + ) + features["STEM_S2"] = x + + # blocks + current_stride = 2 + for current_block_idx, cfg in enumerate(_config): + ( + block_type, + r, + k, + c, + s, + e, + _, + transformer_depth, + patch_size, + ) = cfg + c = int(c * multiplier) + # always apply inverted_residual_block + for current_layer_idx in range(r): + s = s if current_layer_idx == 0 else 1 + name = f"stages_{current_block_idx}_{current_layer_idx}" + x = apply_inverted_residual_block( + x, c, k, 1, 1, s, e, activation=activation, name=name + ) + current_stride *= s + if block_type == "mobilevitv2": + name = f"stages_{current_block_idx}_{current_layer_idx + 1}" + x = apply_mobilevitv2_block( + x, + c, + k, + 1, + 0.5, + mlp_ratio=2.0, + transformer_depth=transformer_depth, + patch_size=patch_size, + activation=activation, + transformer_activation=activation, + name=name, + ) + features[f"BLOCK{current_block_idx}_S{current_stride}"] = x + + # Head + x = self.build_head(x) + + super().__init__(inputs=inputs, outputs=x, features=features, **kwargs) + + # All references to `self` below this line + self.multiplier = multiplier + self.activation = activation + self.config = config + + def get_config(self): + config = super().get_config() + config.update( + { + "multiplier": self.multiplier, + "activation": self.activation, + "config": self.config, + } + ) + return config + + def fix_config(self, config): + unused_kwargs = ["multiplier", "activation", "config"] + for k in unused_kwargs: + config.pop(k, None) + return config + + +class MobileViTXXS(MobileViT): available_weights = [ ( "imagenet", MobileViT.default_origin, - "mobilevits_mobilevit_s.cvnets_in1k.keras", + "mobilevitxxs_mobilevit_xxs.cvnets_in1k.keras", ) ] @@ -320,14 +650,14 @@ def __init__( classes: int = 1000, classifier_activation: str = "softmax", weights: typing.Optional[str] = "imagenet", - config: str = "v1_s", - name="MobileViTS", + config: str = "v1_xxs", + name="MobileViTXXS", **kwargs, ): kwargs = self.fix_config(kwargs) super().__init__( 16, - 640, + 320, "swish", config, input_tensor=input_tensor, @@ -388,12 +718,12 @@ def __init__( ) -class MobileViTXXS(MobileViT): +class MobileViTS(MobileViT): available_weights = [ ( "imagenet", MobileViT.default_origin, - "mobilevitxxs_mobilevit_xxs.cvnets_in1k.keras", + "mobilevits_mobilevit_s.cvnets_in1k.keras", ) ] @@ -408,14 +738,315 @@ def __init__( classes: int = 1000, classifier_activation: str = "softmax", weights: typing.Optional[str] = "imagenet", - config: str = "v1_xxs", - name="MobileViTXXS", + config: str = "v1_s", + name="MobileViTS", **kwargs, ): kwargs = self.fix_config(kwargs) super().__init__( 16, - 320, + 640, + "swish", + config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, + name=name, + **kwargs, + ) + + +class MobileViTV2W050(MobileViTV2): + available_weights = [ + ( + "imagenet", + MobileViTV2.default_origin, + 
"mobilevitv2w050_mobilevitv2_050.cvnets_in1k.keras", + ) + ] + + def __init__( + self, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.1, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = "imagenet", + config: str = "v2", + name="MobileViTV2W050", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + super().__init__( + 0.5, + "swish", + config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, + name=name, + **kwargs, + ) + + +class MobileViTV2W075(MobileViTV2): + available_weights = [ + ( + "imagenet", + MobileViTV2.default_origin, + "mobilevitv2w075_mobilevitv2_075.cvnets_in1k.keras", + ) + ] + + def __init__( + self, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.1, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = "imagenet", + config: str = "v2", + name="MobileViTV2W075", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + super().__init__( + 0.75, + "swish", + config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, + name=name, + **kwargs, + ) + + +class MobileViTV2W100(MobileViTV2): + available_weights = [ + ( + "imagenet", + MobileViTV2.default_origin, + "mobilevitv2w100_mobilevitv2_100.cvnets_in1k.keras", + ) + ] + + def __init__( + self, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.1, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = "imagenet", + config: str = "v2", + name="MobileViTV2W100", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + super().__init__( + 1.0, + "swish", + config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, + name=name, + **kwargs, + ) + + +class MobileViTV2W125(MobileViTV2): + available_weights = [ + ( + "imagenet", + MobileViTV2.default_origin, + "mobilevitv2w125_mobilevitv2_125.cvnets_in1k.keras", + ) + ] + + def __init__( + self, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.1, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = "imagenet", + config: str = "v2", + name="MobileViTV2W125", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + super().__init__( + 1.25, + "swish", + 
config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, + name=name, + **kwargs, + ) + + +class MobileViTV2W150(MobileViTV2): + available_weights = [ + ( + "imagenet", + MobileViTV2.default_origin, + "mobilevitv2w150_mobilevitv2_150.cvnets_in22k_ft_in1k_384.keras", + ) + ] + + def __init__( + self, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.1, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = "imagenet", + config: str = "v2", + name="MobileViTV2W150", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + super().__init__( + 1.5, + "swish", + config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, + name=name, + **kwargs, + ) + + +class MobileViTV2W175(MobileViTV2): + available_weights = [ + ( + "imagenet", + MobileViTV2.default_origin, + "mobilevitv2w175_mobilevitv2_175.cvnets_in22k_ft_in1k_384.keras", + ) + ] + + def __init__( + self, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.1, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = "imagenet", + config: str = "v2", + name="MobileViTV2W175", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + super().__init__( + 1.75, + "swish", + config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, + name=name, + **kwargs, + ) + + +class MobileViTV2W200(MobileViTV2): + available_weights = [ + ( + "imagenet", + MobileViTV2.default_origin, + "mobilevitv2w200_mobilevitv2_200.cvnets_in22k_ft_in1k_384.keras", + ) + ] + + def __init__( + self, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.1, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = "imagenet", + config: str = "v2", + name="MobileViTV2W200", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + super().__init__( + 2.0, "swish", config, input_tensor=input_tensor, @@ -432,6 +1063,13 @@ def __init__( ) -add_model_to_registry(MobileViTS, "imagenet") -add_model_to_registry(MobileViTXS, "imagenet") add_model_to_registry(MobileViTXXS, "imagenet") +add_model_to_registry(MobileViTXS, "imagenet") +add_model_to_registry(MobileViTS, "imagenet") +add_model_to_registry(MobileViTV2W050, "imagenet") +add_model_to_registry(MobileViTV2W075, "imagenet") +add_model_to_registry(MobileViTV2W100, "imagenet") +add_model_to_registry(MobileViTV2W125, "imagenet") 
+add_model_to_registry(MobileViTV2W150, "imagenet") +add_model_to_registry(MobileViTV2W175, "imagenet") +add_model_to_registry(MobileViTV2W200, "imagenet") diff --git a/kimm/models/models_test.py b/kimm/models/models_test.py index 73b418d..d0a6630 100644 --- a/kimm/models/models_test.py +++ b/kimm/models/models_test.py @@ -138,6 +138,19 @@ ("BLOCK7_S32", [1, 7, 7, 160]), ], ), + # inception_next + ( + kimm_models.InceptionNeXtTiny.__name__, + kimm_models.InceptionNeXtTiny, + 224, + [ + ("STEM_S4", [1, 56, 56, 96]), + ("BLOCK0_S4", [1, 56, 56, 96]), + ("BLOCK1_S8", [1, 28, 28, 192]), + ("BLOCK2_S16", [1, 14, 14, 384]), + ("BLOCK3_S32", [1, 7, 7, 768]), + ], + ), # inception_v3 ( kimm_models.InceptionV3.__name__, @@ -153,8 +166,8 @@ ), # mobilenet_v2 ( - kimm_models.MobileNet050V2.__name__, - kimm_models.MobileNet050V2, + kimm_models.MobileNetV2W050.__name__, + kimm_models.MobileNetV2W050, 224, [ ("STEM_S2", [1, 112, 112, make_divisible(32 * 0.5)]), @@ -164,18 +177,6 @@ ("BLOCK5_S32", [1, 7, 7, make_divisible(160 * 0.5)]), ], ), - ( - kimm_models.MobileNet100V2.__name__, - kimm_models.MobileNet100V2, - 224, - [ - ("STEM_S2", [1, 112, 112, make_divisible(32 * 1.0)]), - ("BLOCK1_S4", [1, 56, 56, make_divisible(24 * 1.0)]), - ("BLOCK2_S8", [1, 28, 28, make_divisible(32 * 1.0)]), - ("BLOCK3_S16", [1, 14, 14, make_divisible(64 * 1.0)]), - ("BLOCK5_S32", [1, 7, 7, make_divisible(160 * 1.0)]), - ], - ), # mobilenet_v3 ( kimm_models.LCNet100.__name__, @@ -190,32 +191,20 @@ ], ), ( - kimm_models.MobileNet100V3Large.__name__, - kimm_models.MobileNet100V3Large, + kimm_models.MobileNetV3W050Small.__name__, + kimm_models.MobileNetV3W050Small, 224, [ - ("STEM_S2", [1, 112, 112, make_divisible(16 * 1.0)]), - ("BLOCK1_S4", [1, 56, 56, make_divisible(24 * 1.0)]), - ("BLOCK2_S8", [1, 28, 28, make_divisible(40 * 1.0)]), - ("BLOCK3_S16", [1, 14, 14, make_divisible(80 * 1.0)]), - ("BLOCK5_S32", [1, 7, 7, make_divisible(160 * 1.0)]), - ], - ), - ( - kimm_models.MobileNet100V3Small.__name__, - kimm_models.MobileNet100V3Small, - 224, - [ - ("STEM_S2", [1, 112, 112, make_divisible(16 * 1.0)]), - ("BLOCK0_S4", [1, 56, 56, make_divisible(16 * 1.0)]), - ("BLOCK1_S8", [1, 28, 28, make_divisible(24 * 1.0)]), - ("BLOCK2_S16", [1, 14, 14, make_divisible(40 * 1.0)]), - ("BLOCK4_S32", [1, 7, 7, make_divisible(96 * 1.0)]), + ("STEM_S2", [1, 112, 112, 16]), + ("BLOCK0_S4", [1, 56, 56, 8]), + ("BLOCK1_S8", [1, 28, 28, 16]), + ("BLOCK2_S16", [1, 14, 14, 24]), + ("BLOCK4_S32", [1, 7, 7, 48]), ], ), ( - kimm_models.MobileNet100V3SmallMinimal.__name__, - kimm_models.MobileNet100V3SmallMinimal, + kimm_models.MobileNetV3W100SmallMinimal.__name__, + kimm_models.MobileNetV3W100SmallMinimal, 224, [ ("STEM_S2", [1, 112, 112, make_divisible(16 * 1.0)]), @@ -238,16 +227,17 @@ ("BLOCK4_S32", [1, 8, 8, 160]), ], ), + # mobilevitv2 ( - kimm_models.MobileViTXS.__name__, - kimm_models.MobileViTXS, + kimm_models.MobileViTV2W050.__name__, + kimm_models.MobileViTV2W050, 256, [ ("STEM_S2", [1, 128, 128, 16]), - ("BLOCK1_S4", [1, 64, 64, 48]), - ("BLOCK2_S8", [1, 32, 32, 64]), - ("BLOCK3_S16", [1, 16, 16, 80]), - ("BLOCK4_S32", [1, 8, 8, 96]), + ("BLOCK1_S4", [1, 64, 64, 64]), + ("BLOCK2_S8", [1, 32, 32, 128]), + ("BLOCK3_S16", [1, 16, 16, 192]), + ("BLOCK4_S32", [1, 8, 8, 256]), ], ), # regnet diff --git a/requirements.txt b/requirements.txt index f54fabf..37576e8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,7 +10,7 @@ # "jax[cuda12_local]" # Following is for github runner -tf-nightly-cpu==2.16.0.dev20240108 
+tf-nightly-cpu==2.16.0.dev20240101 --extra-index-url https://download.pytorch.org/whl/cpu torch>=2.1.0 @@ -18,4 +18,4 @@ torchvision>=0.16.0 jax[cpu] -keras +keras>=3.0.4 diff --git a/shell/export.sh b/shell/export.sh index e788389..7ec49b8 100755 --- a/shell/export.sh +++ b/shell/export.sh @@ -9,6 +9,7 @@ python3 -m tools.convert_convnext_from_timm python3 -m tools.convert_densenet_from_timm python3 -m tools.convert_efficientnet_from_timm python3 -m tools.convert_ghostnet_from_timm +python3 -m tools.convert_inception_next_from_timm python3 -m tools.convert_inception_v3_from_timm python3 -m tools.convert_mobilenet_v2_from_timm python3 -m tools.convert_mobilenet_v3_from_timm diff --git a/tools/convert_convmixer_from_timm.py b/tools/convert_convmixer_from_timm.py index cbff78a..f637d2b 100644 --- a/tools/convert_convmixer_from_timm.py +++ b/tools/convert_convmixer_from_timm.py @@ -2,6 +2,7 @@ pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu pip install timm """ + import os import keras diff --git a/tools/convert_convnext_from_timm.py b/tools/convert_convnext_from_timm.py index c819155..4c061e8 100644 --- a/tools/convert_convnext_from_timm.py +++ b/tools/convert_convnext_from_timm.py @@ -2,6 +2,7 @@ pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu pip install timm """ + import os import keras diff --git a/tools/convert_densenet_from_timm.py b/tools/convert_densenet_from_timm.py index 5d882a7..fa17d5a 100644 --- a/tools/convert_densenet_from_timm.py +++ b/tools/convert_densenet_from_timm.py @@ -2,6 +2,7 @@ pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu pip install timm """ + import os import keras diff --git a/tools/convert_efficientnet_from_timm.py b/tools/convert_efficientnet_from_timm.py index 3af0f24..040354a 100644 --- a/tools/convert_efficientnet_from_timm.py +++ b/tools/convert_efficientnet_from_timm.py @@ -2,6 +2,7 @@ pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu pip install timm """ + import os import keras diff --git a/tools/convert_ghostnet_from_timm.py b/tools/convert_ghostnet_from_timm.py index 7ffd688..25587ac 100644 --- a/tools/convert_ghostnet_from_timm.py +++ b/tools/convert_ghostnet_from_timm.py @@ -2,6 +2,7 @@ pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu pip install timm """ + import os import keras diff --git a/tools/convert_inception_next_from_timm.py b/tools/convert_inception_next_from_timm.py new file mode 100644 index 0000000..893f57a --- /dev/null +++ b/tools/convert_inception_next_from_timm.py @@ -0,0 +1,141 @@ +""" +pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu +pip install timm +""" + +import os + +import keras +import numpy as np +import timm +import torch + +from kimm.models import inception_next +from kimm.utils.timm_utils import assign_weights +from kimm.utils.timm_utils import is_same_weights +from kimm.utils.timm_utils import separate_keras_weights +from kimm.utils.timm_utils import separate_torch_state_dict + +timm_model_names = [ + "inception_next_tiny.sail_in1k", + "inception_next_small.sail_in1k", + "inception_next_base.sail_in1k_384", +] +keras_model_classes = [ + inception_next.InceptionNeXtTiny, + inception_next.InceptionNeXtSmall, + inception_next.InceptionNeXtBase, +] + +for timm_model_name, keras_model_class in zip( + timm_model_names, keras_model_classes +): + """ + Prepare timm model and keras model + """ + input_shape = [224, 224, 3] + torch_model = 
timm.create_model(timm_model_name, pretrained=True) + torch_model = torch_model.eval() + trainable_state_dict, non_trainable_state_dict = separate_torch_state_dict( + torch_model.state_dict() + ) + keras_model = keras_model_class( + input_shape=input_shape, + include_preprocessing=False, + classifier_activation="linear", + ) + trainable_weights, non_trainable_weights = separate_keras_weights( + keras_model + ) + + # for torch_name, (_, keras_name) in zip( + # trainable_state_dict.keys(), trainable_weights + # ): + # print(f"{torch_name} {keras_name}") + + # print(len(trainable_state_dict.keys())) + # print(len(trainable_weights)) + + # exit() + + """ + Assign weights + """ + for keras_weight, keras_name in trainable_weights + non_trainable_weights: + # prevent gamma to be replaced + is_layerscale = False + keras_name: str + torch_name = keras_name + torch_name = torch_name.replace("_", ".") + + # stem + torch_name = torch_name.replace("stem.0.conv2d.kernel", "stem.0.weight") + torch_name = torch_name.replace("stem.0.conv2d.bias", "stem.0.bias") + + # blocks + torch_name = torch_name.replace("dwconv2d.", "") + torch_name = torch_name.replace("conv2d.", "") + torch_name = torch_name.replace("conv.dw", "conv_dw") + if "layerscale" in torch_name: + is_layerscale = True + torch_name = torch_name.replace("layerscale.", "") + torch_name = torch_name.replace("token.mixer", "token_mixer") + torch_name = torch_name.replace("dwconv.hw.", "dwconv_hw.") + torch_name = torch_name.replace("dwconv.w.", "dwconv_w.") + torch_name = torch_name.replace("dwconv.h.", "dwconv_h.") + # head + torch_name = torch_name.replace("classifier", "head.fc2") + + # weights naming mapping + torch_name = torch_name.replace("kernel", "weight") # conv2d + if not is_layerscale: + torch_name = torch_name.replace("gamma", "weight") # bn + torch_name = torch_name.replace("beta", "bias") # bn + torch_name = torch_name.replace("moving.mean", "running_mean") # bn + torch_name = torch_name.replace("moving.variance", "running_var") # bn + + # assign weights + if torch_name in trainable_state_dict: + torch_weights = trainable_state_dict[torch_name].numpy() + elif torch_name in non_trainable_state_dict: + torch_weights = non_trainable_state_dict[torch_name].numpy() + else: + raise ValueError( + "Can't find the corresponding torch weights. " + f"Got keras_name={keras_name}, torch_name={torch_name}" + ) + if is_layerscale: + assign_weights(keras_name, keras_weight, torch_weights) + elif is_same_weights( + keras_name, keras_weight, torch_name, torch_weights + ): + assign_weights(keras_name, keras_weight, torch_weights) + else: + raise ValueError( + "Can't find the corresponding torch weights. The shape is " + f"mismatched. 
Got keras_name={keras_name}, " + f"keras_weight shape={keras_weight.shape}, " + f"torch_name={torch_name}, " + f"torch_weights shape={torch_weights.shape}" + ) + + """ + Verify model outputs + """ + np.random.seed(2023) + keras_data = np.random.uniform(size=[1] + input_shape).astype("float32") + torch_data = torch.from_numpy(np.transpose(keras_data, [0, 3, 1, 2])) + torch_y = torch_model(torch_data) + keras_y = keras_model(keras_data, training=False) + torch_y = torch_y.detach().cpu().numpy() + keras_y = keras.ops.convert_to_numpy(keras_y) + np.testing.assert_allclose(torch_y, keras_y, atol=1e-4) + print(f"{keras_model_class.__name__}: output matched!") + + """ + Save converted model + """ + os.makedirs("exported", exist_ok=True) + export_path = f"exported/{keras_model.name.lower()}_{timm_model_name}.keras" + keras_model.save(export_path) + print(f"Export to {export_path}") diff --git a/tools/convert_inception_v3_from_timm.py b/tools/convert_inception_v3_from_timm.py index 5129b19..4851084 100644 --- a/tools/convert_inception_v3_from_timm.py +++ b/tools/convert_inception_v3_from_timm.py @@ -2,6 +2,7 @@ pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu pip install timm """ + import os import keras diff --git a/tools/convert_mobilenet_v2_from_timm.py b/tools/convert_mobilenet_v2_from_timm.py index ec55782..136c5a5 100644 --- a/tools/convert_mobilenet_v2_from_timm.py +++ b/tools/convert_mobilenet_v2_from_timm.py @@ -2,6 +2,7 @@ pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu pip install timm """ + import os import keras @@ -23,11 +24,11 @@ "mobilenetv2_140.ra_in1k", ] keras_model_classes = [ - mobilenet_v2.MobileNet050V2, - mobilenet_v2.MobileNet100V2, - mobilenet_v2.MobileNet110V2, - mobilenet_v2.MobileNet120V2, - mobilenet_v2.MobileNet140V2, + mobilenet_v2.MobileNetV2W050, + mobilenet_v2.MobileNetV2W100, + mobilenet_v2.MobileNetV2W110, + mobilenet_v2.MobileNetV2W120, + mobilenet_v2.MobileNetV2W140, ] for timm_model_name, keras_model_class in zip( diff --git a/tools/convert_mobilenet_v3_from_timm.py b/tools/convert_mobilenet_v3_from_timm.py index 2663450..df426d1 100644 --- a/tools/convert_mobilenet_v3_from_timm.py +++ b/tools/convert_mobilenet_v3_from_timm.py @@ -2,6 +2,7 @@ pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu pip install timm """ + import os import keras @@ -27,12 +28,12 @@ "lcnet_100.ra2_in1k", ] keras_model_classes = [ - mobilenet_v3.MobileNet050V3Small, - mobilenet_v3.MobileNet075V3Small, - mobilenet_v3.MobileNet100V3SmallMinimal, - mobilenet_v3.MobileNet100V3Small, - mobilenet_v3.MobileNet100V3Large, - mobilenet_v3.MobileNet100V3LargeMinimal, + mobilenet_v3.MobileNetV3W050Small, + mobilenet_v3.MobileNetV3W075Small, + mobilenet_v3.MobileNetV3W100SmallMinimal, + mobilenet_v3.MobileNetV3W100Small, + mobilenet_v3.MobileNetV3W100Large, + mobilenet_v3.MobileNetV3W100LargeMinimal, mobilenet_v3.LCNet050, mobilenet_v3.LCNet075, mobilenet_v3.LCNet100, diff --git a/tools/convert_mobilevit_from_timm.py b/tools/convert_mobilevit_from_timm.py index 5000e74..a90b9b0 100644 --- a/tools/convert_mobilevit_from_timm.py +++ b/tools/convert_mobilevit_from_timm.py @@ -2,6 +2,7 @@ pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu pip install timm """ + import os import keras @@ -16,14 +17,28 @@ from kimm.utils.timm_utils import separate_torch_state_dict timm_model_names = [ - "mobilevit_s.cvnets_in1k", - "mobilevit_xs.cvnets_in1k", "mobilevit_xxs.cvnets_in1k", + 
"mobilevit_xs.cvnets_in1k", + "mobilevit_s.cvnets_in1k", + "mobilevitv2_050.cvnets_in1k", + "mobilevitv2_075.cvnets_in1k", + "mobilevitv2_100.cvnets_in1k", + "mobilevitv2_125.cvnets_in1k", + "mobilevitv2_150.cvnets_in22k_ft_in1k_384", + "mobilevitv2_175.cvnets_in22k_ft_in1k_384", + "mobilevitv2_200.cvnets_in22k_ft_in1k_384", ] keras_model_classes = [ - mobilevit.MobileViTS, - mobilevit.MobileViTXS, mobilevit.MobileViTXXS, + mobilevit.MobileViTXS, + mobilevit.MobileViTS, + mobilevit.MobileViTV2W050, + mobilevit.MobileViTV2W075, + mobilevit.MobileViTV2W100, + mobilevit.MobileViTV2W125, + mobilevit.MobileViTV2W150, + mobilevit.MobileViTV2W175, + mobilevit.MobileViTV2W200, ] for timm_model_name, keras_model_class in zip( @@ -42,6 +57,7 @@ input_shape=input_shape, include_preprocessing=False, classifier_activation="linear", + weights=None, ) trainable_weights, non_trainable_weights = separate_keras_weights( keras_model @@ -85,6 +101,16 @@ "conv.fusion.conv2d", "conv_fusion.conv" ) torch_name = torch_name.replace("conv.fusion.bn", "conv_fusion.bn") + # mobilevitv2 block + torch_name = torch_name.replace("conv.kxk.dwconv2d", "conv_kxk.conv") + torch_name = torch_name.replace( + "attn.qkv.qkv.proj.conv2d", "attn.qkv_proj" + ) + torch_name = torch_name.replace( + "attn.qkv.out.proj.conv2d", "attn.out_proj" + ) + torch_name = torch_name.replace("mlp.fc1.conv2d", "mlp.fc1") + torch_name = torch_name.replace("mlp.fc2.conv2d", "mlp.fc2") # final block torch_name = torch_name.replace("final.conv.conv2d", "final_conv.conv") torch_name = torch_name.replace("final.conv.bn", "final_conv.bn") diff --git a/tools/convert_regnet_from_timm.py b/tools/convert_regnet_from_timm.py index 33372bf..0343277 100644 --- a/tools/convert_regnet_from_timm.py +++ b/tools/convert_regnet_from_timm.py @@ -2,6 +2,7 @@ pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu pip install timm """ + import os import keras diff --git a/tools/convert_resnet_from_timm.py b/tools/convert_resnet_from_timm.py index 4dde8bf..ff75bd7 100644 --- a/tools/convert_resnet_from_timm.py +++ b/tools/convert_resnet_from_timm.py @@ -2,6 +2,7 @@ pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu pip install timm """ + import os import keras diff --git a/tools/convert_vgg_from_timm.py b/tools/convert_vgg_from_timm.py index e5d1894..410bf9c 100644 --- a/tools/convert_vgg_from_timm.py +++ b/tools/convert_vgg_from_timm.py @@ -2,6 +2,7 @@ pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu pip install timm """ + import os import keras diff --git a/tools/convert_vit_from_timm.py b/tools/convert_vit_from_timm.py index f0fd93d..fd9ba2d 100644 --- a/tools/convert_vit_from_timm.py +++ b/tools/convert_vit_from_timm.py @@ -2,6 +2,7 @@ pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu pip install timm """ + import os import keras