diff --git a/README.md b/README.md index 410ec69..d62b5d6 100644 --- a/README.md +++ b/README.md @@ -154,6 +154,8 @@ Reference: [Grad-CAM class activation visualization (keras.io)](https://keras.io |EfficientNetV2|[ICML 2021](https://arxiv.org/abs/2104.00298)|`timm`|`kimm.models.EfficientNetV2*`| |GhostNet|[CVPR 2020](https://arxiv.org/abs/1911.11907)|`timm`|`kimm.models.GhostNet*`| |GhostNetV2|[NeurIPS 2022](https://arxiv.org/abs/2211.12905)|`timm`|`kimm.models.GhostNetV2*`| +|HGNet||`timm`|`kimm.models.HGNet*`| +|HGNetV2||`timm`|`kimm.models.HGNetV2*`| |InceptionNeXt|[arXiv 2023](https://arxiv.org/abs/2303.16900)|`timm`|`kimm.models.InceptionNeXt*`| |InceptionV3|[CVPR 2016](https://arxiv.org/abs/1512.00567)|`timm`|`kimm.models.InceptionV3`| |LCNet|[arXiv 2021](https://arxiv.org/abs/2109.15099)|`timm`|`kimm.models.LCNet*`| diff --git a/kimm/blocks/base_block.py b/kimm/blocks/base_block.py index 430745f..9091c7a 100644 --- a/kimm/blocks/base_block.py +++ b/kimm/blocks/base_block.py @@ -55,7 +55,8 @@ def apply_conv2d_block( if strides > 1: padding = "valid" x = layers.ZeroPadding2D( - (kernel_size[0] // 2, kernel_size[1] // 2), name=f"{name}_pad" + ((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2), + name=f"{name}_pad", )(x) if not use_depthwise: diff --git a/kimm/layers/__init__.py b/kimm/layers/__init__.py index f21f54f..e85a569 100644 --- a/kimm/layers/__init__.py +++ b/kimm/layers/__init__.py @@ -1,5 +1,6 @@ from kimm.layers.attention import Attention from kimm.layers.layer_scale import LayerScale +from kimm.layers.learnable_affine import LearnableAffine from kimm.layers.mobile_one_conv2d import MobileOneConv2D from kimm.layers.position_embedding import PositionEmbedding from kimm.layers.rep_conv2d import RepConv2D diff --git a/kimm/layers/learnable_affine.py b/kimm/layers/learnable_affine.py new file mode 100644 index 0000000..6e41b3f --- /dev/null +++ b/kimm/layers/learnable_affine.py @@ -0,0 +1,50 @@ +import keras +from keras import layers +from keras import ops + + +@keras.saving.register_keras_serializable(package="kimm") +class LearnableAffine(layers.Layer): + def __init__(self, scale_value=1.0, bias_value=0.0, **kwargs): + super().__init__(**kwargs) + if isinstance(scale_value, int): + raise ValueError( + f"scale_value must be a integer. Received: {scale_value}" + ) + if isinstance(bias_value, int): + raise ValueError( + f"bias_value must be a integer. Received: {bias_value}" + ) + self.scale_value = scale_value + self.bias_value = bias_value + + def build(self, input_shape): + self.scale = self.add_weight( + shape=(1,), + initializer=lambda shape, dtype: ops.cast(self.scale_value, dtype), + trainable=True, + name="scale", + ) + self.bias = self.add_weight( + shape=(1,), + initializer=lambda shape, dtype: ops.cast(self.bias_value, dtype), + trainable=True, + name="bias", + ) + self.built = True + + def call(self, inputs, training=None, mask=None): + scale = ops.cast(self.scale, self.compute_dtype) + bias = ops.cast(self.bias, self.compute_dtype) + return ops.add(ops.multiply(inputs, scale), bias) + + def get_config(self): + config = super().get_config() + config.update( + { + "scale_value": self.scale_value, + "bias_value": self.bias_value, + "name": self.name, + } + ) + return config diff --git a/kimm/layers/learnable_affine_test.py b/kimm/layers/learnable_affine_test.py new file mode 100644 index 0000000..8aa2335 --- /dev/null +++ b/kimm/layers/learnable_affine_test.py @@ -0,0 +1,20 @@ +import pytest +from absl.testing import parameterized +from keras.src import testing + +from kimm.layers.learnable_affine import LearnableAffine + + +class LearnableAffineTest(testing.TestCase, parameterized.TestCase): + @pytest.mark.requires_trainable_backend + def test_layer_scale_basic(self): + self.run_layer_test( + LearnableAffine, + init_kwargs={"scale_value": 1.0, "bias_value": 0.0}, + input_shape=(1, 10), + expected_output_shape=(1, 10), + expected_num_trainable_weights=2, + expected_num_non_trainable_weights=0, + expected_num_losses=0, + supports_masking=False, + ) diff --git a/kimm/models/__init__.py b/kimm/models/__init__.py index 382b4b4..2cc253d 100644 --- a/kimm/models/__init__.py +++ b/kimm/models/__init__.py @@ -4,6 +4,7 @@ from kimm.models.densenet import * # noqa:F403 from kimm.models.efficientnet import * # noqa:F403 from kimm.models.ghostnet import * # noqa:F403 +from kimm.models.hgnet import * # noqa:F403 from kimm.models.inception_next import * # noqa:F403 from kimm.models.inception_v3 import * # noqa:F403 from kimm.models.mobilenet_v2 import * # noqa:F403 diff --git a/kimm/models/hgnet.py b/kimm/models/hgnet.py new file mode 100644 index 0000000..338e3b8 --- /dev/null +++ b/kimm/models/hgnet.py @@ -0,0 +1,968 @@ +import typing + +import keras +from keras import backend +from keras import layers + +from kimm import layers as kimm_layers +from kimm.blocks import apply_conv2d_block +from kimm.models.base_model import BaseModel +from kimm.utils import add_model_to_registry + +DEFAULT_V1_TINY_CONFIG = dict( + stem_type="v1", + stem_channels=[48, 48, 96], + use_learnable_affine=False, + # input_channels, hidden_channels, output_channels, blocks, downsample, + # light_block, kernel_size, num_layers + stage1=[96, 96, 224, 1, False, False, 3, 5], + stage2=[224, 128, 448, 1, True, False, 3, 5], + stage3=[448, 160, 512, 2, True, False, 3, 5], + stage4=[512, 192, 768, 1, True, False, 3, 5], +) +DEFAULT_V1_SMALL_CONFIG = dict( + stem_type="v1", + stem_channels=[64, 64, 128], + use_learnable_affine=False, + # input_channels, hidden_channels, output_channels, blocks, downsample, + # light_block, kernel_size, num_layers + stage1=[128, 128, 256, 1, False, False, 3, 6], + stage2=[256, 160, 512, 1, True, False, 3, 6], + stage3=[512, 192, 768, 2, True, False, 3, 6], + stage4=[768, 224, 1024, 1, True, False, 3, 6], +) +DEFAULT_V1_BASE_CONFIG = dict( + stem_type="v1", + stem_channels=[96, 96, 160], + use_learnable_affine=False, + # input_channels, hidden_channels, output_channels, blocks, downsample, + # light_block, kernel_size, num_layers + stage1=[160, 192, 320, 1, False, False, 3, 7], + stage2=[320, 224, 640, 2, True, False, 3, 7], + stage3=[640, 256, 960, 3, True, False, 3, 7], + stage4=[960, 288, 1280, 2, True, False, 3, 7], +) +DEFAULT_V2_B0_CONFIG = dict( + stem_type="v2", + stem_channels=[16, 16], + use_learnable_affine=True, + # input_channels, hidden_channels, output_channels, blocks, downsample, + # light_block, kernel_size, num_layers + stage1=[16, 16, 64, 1, False, False, 3, 3], + stage2=[64, 32, 256, 1, True, False, 3, 3], + stage3=[256, 64, 512, 2, True, True, 5, 3], + stage4=[512, 128, 1024, 1, True, True, 5, 3], +) +DEFAULT_V2_B1_CONFIG = dict( + stem_type="v2", + stem_channels=[24, 32], + use_learnable_affine=True, + # input_channels, hidden_channels, output_channels, blocks, downsample, + # light_block, kernel_size, num_layers + stage1=[32, 32, 64, 1, False, False, 3, 3], + stage2=[64, 48, 256, 1, True, False, 3, 3], + stage3=[256, 96, 512, 2, True, True, 5, 3], + stage4=[512, 192, 1024, 1, True, True, 5, 3], +) +DEFAULT_V2_B2_CONFIG = dict( + stem_type="v2", + stem_channels=[24, 32], + use_learnable_affine=True, + # input_channels, hidden_channels, output_channels, blocks, downsample, + # light_block, kernel_size, num_layers + stage1=[32, 32, 96, 1, False, False, 3, 4], + stage2=[96, 64, 384, 1, True, False, 3, 4], + stage3=[384, 128, 768, 3, True, True, 5, 4], + stage4=[768, 256, 1536, 1, True, True, 5, 4], +) +DEFAULT_V2_B3_CONFIG = dict( + stem_type="v2", + stem_channels=[24, 32], + use_learnable_affine=True, + # input_channels, hidden_channels, output_channels, blocks, downsample, + # light_block, kernel_size, num_layers + stage1=[32, 32, 128, 1, False, False, 3, 5], + stage2=[128, 64, 512, 1, True, False, 3, 5], + stage3=[512, 128, 1024, 3, True, True, 5, 5], + stage4=[1024, 256, 2048, 1, True, True, 5, 5], +) +DEFAULT_V2_B4_CONFIG = dict( + stem_type="v2", + stem_channels=[32, 48], + use_learnable_affine=False, + # input_channels, hidden_channels, output_channels, blocks, downsample, + # light_block, kernel_size, num_layers + stage1=[48, 48, 128, 1, False, False, 3, 6], + stage2=[128, 96, 512, 1, True, False, 3, 6], + stage3=[512, 192, 1024, 3, True, True, 5, 6], + stage4=[1024, 384, 2048, 1, True, True, 5, 6], +) +DEFAULT_V2_B5_CONFIG = dict( + stem_type="v2", + stem_channels=[32, 64], + use_learnable_affine=False, + # input_channels, hidden_channels, output_channels, blocks, downsample, + # light_block, kernel_size, num_layers + stage1=[64, 64, 128, 1, False, False, 3, 6], + stage2=[128, 128, 512, 2, True, False, 3, 6], + stage3=[512, 256, 1024, 5, True, True, 5, 6], + stage4=[1024, 512, 2048, 2, True, True, 5, 6], +) +DEFAULT_V2_B6_CONFIG = dict( + stem_type="v2", + stem_channels=[48, 96], + use_learnable_affine=False, + # input_channels, hidden_channels, output_channels, blocks, downsample, + # light_block, kernel_size, num_layers + stage1=[96, 96, 192, 2, False, False, 3, 6], + stage2=[192, 192, 512, 3, True, False, 3, 6], + stage3=[512, 384, 1024, 6, True, True, 5, 6], + stage4=[1024, 768, 2048, 3, True, True, 5, 6], +) + + +def apply_conv_bn_act_block( + inputs, + filters, + kernel_size, + strides=1, + activation="relu", + use_depthwise=False, + padding=None, + use_learnable_affine=False, + name="conv_bn_act_block", +): + x = inputs + x = apply_conv2d_block( + x, + filters, + kernel_size, + strides, + activation=activation, + use_depthwise=use_depthwise, + padding=padding, + name=name, + ) + if activation is not None and use_learnable_affine: + x = kimm_layers.LearnableAffine(name=f"{name}_lab")(x) + return x + + +def apply_stem_v1(inputs, stem_channels, name="stem_v1"): + x = inputs + for i, c in enumerate(stem_channels): + x = apply_conv_bn_act_block( + x, c, 3, strides=2 if i == 0 else 1, name=f"{name}_{i}" + ) + x = layers.ZeroPadding2D(padding=1)(x) + x = layers.MaxPooling2D(pool_size=3, strides=2)(x) + return x + + +def apply_stem_v2( + inputs, + hidden_channels, + output_channels, + use_learnable_affine=False, + name="stem_v2", +): + channels_axis = -1 if backend.image_data_format() == "channels_last" else -3 + + x = inputs + x = apply_conv_bn_act_block( + x, + hidden_channels, + 3, + 2, + use_learnable_affine=use_learnable_affine, + name=f"{name}_stem1", + ) + x = layers.ZeroPadding2D(padding=((0, 1), (0, 1)))(x) + + x2 = apply_conv_bn_act_block( + x, + hidden_channels // 2, + 2, + 1, + padding="valid", + use_learnable_affine=use_learnable_affine, + name=f"{name}_stem2a", + ) + x2 = layers.ZeroPadding2D(padding=((0, 1), (0, 1)))(x2) + x2 = apply_conv_bn_act_block( + x2, + hidden_channels, + 2, + 1, + padding="valid", + use_learnable_affine=use_learnable_affine, + name=f"{name}_stem2b", + ) + + x1 = layers.MaxPooling2D(pool_size=2, strides=1)(x) + x = layers.Concatenate(axis=channels_axis)([x1, x2]) + x = apply_conv_bn_act_block( + x, + hidden_channels, + 3, + 2, + use_learnable_affine=use_learnable_affine, + name=f"{name}_stem3", + ) + x = apply_conv_bn_act_block( + x, + output_channels, + 1, + 1, + use_learnable_affine=use_learnable_affine, + name=f"{name}_stem4", + ) + return x + + +def apply_light_conv_bn_act_block( + inputs, + filters, + kernel_size, + use_learnable_affine=False, + name="light_conv_bn_act_block", +): + x = inputs + x = apply_conv_bn_act_block( + x, + filters, + 1, + activation=None, + use_learnable_affine=use_learnable_affine, + name=f"{name}_conv1", + ) + x = apply_conv_bn_act_block( + x, + filters, + kernel_size, + activation="relu", + use_depthwise=True, + use_learnable_affine=use_learnable_affine, + name=f"{name}_conv2", + ) + return x + + +def apply_ese_module(inputs, channels, name="ese_module"): + x = inputs + x = layers.GlobalAveragePooling2D(keepdims=True)(x) + x = layers.Conv2D( + channels, 1, 1, "valid", use_bias=True, name=f"{name}_conv" + )(x) + x = layers.Activation("sigmoid")(x) + x = layers.Multiply()([inputs, x]) + return x + + +def apply_high_perf_gpu_block( + inputs, + num_layers, + hidden_channels, + output_channels, + kernel_size, + add_skip=False, + use_light_block=False, + use_learnable_affine=False, + aggregation="ese", + name="high_perf_gpu_block", +): + if aggregation not in ("se", "ese"): + raise ValueError( + "aggregation must be one of ('se', 'ese'). " + f"Receviced: aggregation={aggregation}" + ) + channels_axis = -1 if backend.image_data_format() == "channels_last" else -3 + + x = inputs + outputs = [x] + for i in range(num_layers): + if use_light_block: + x = apply_light_conv_bn_act_block( + x, + hidden_channels, + kernel_size, + use_learnable_affine=use_learnable_affine, + name=f"{name}_layers_{i}", + ) + else: + x = apply_conv_bn_act_block( + x, + hidden_channels, + kernel_size, + strides=1, + use_learnable_affine=use_learnable_affine, + name=f"{name}_layers_{i}", + ) + outputs.append(x) + x = layers.Concatenate(axis=channels_axis)(outputs) + if aggregation == "se": + x = apply_conv_bn_act_block( + x, + output_channels // 2, + 1, + 1, + use_learnable_affine=use_learnable_affine, + name=f"{name}_aggregation_0", + ) + x = apply_conv_bn_act_block( + x, + output_channels, + 1, + 1, + use_learnable_affine=use_learnable_affine, + name=f"{name}_aggregation_1", + ) + else: + x = apply_conv_bn_act_block( + x, + output_channels, + 1, + 1, + use_learnable_affine=use_learnable_affine, + name=f"{name}_aggregation_0", + ) + x = apply_ese_module(x, output_channels, name=f"{name}_aggregation_1") + if add_skip: + x = layers.Add()([x, inputs]) + return x + + +def apply_high_perf_gpu_stage( + inputs, + num_blocks, + num_layers, + hidden_channels, + output_channels, + kernel_size=3, + strides=2, + downsample=True, + use_light_block=False, + use_learnable_affine=False, + aggregation="ese", + name="high_perf_gpu_stage", +): + if aggregation not in ("se", "ese"): + raise ValueError( + "aggregation must be one of ('se', 'ese'). " + f"Receviced: aggregation={aggregation}" + ) + channels_axis = -1 if backend.image_data_format() == "channels_last" else -3 + input_channels = inputs.shape[channels_axis] + + x = inputs + if downsample: + x = apply_conv_bn_act_block( + x, + input_channels, + 3, + strides, + activation=None, + use_depthwise=True, + use_learnable_affine=use_learnable_affine, + name=f"{name}_downsample", + ) + for i in range(num_blocks): + x = apply_high_perf_gpu_block( + x, + num_layers, + hidden_channels, + output_channels, + kernel_size, + add_skip=False if i == 0 else True, + use_light_block=use_light_block, + use_learnable_affine=use_learnable_affine, + aggregation=aggregation, + name=f"{name}_blocks_{i}", + ) + return x + + +@keras.saving.register_keras_serializable(package="kimm") +class HGNet(BaseModel): + available_feature_keys = [ + "STEM_S4", + *[f"BLOCK{i}_S{j}" for i, j in zip(range(4), [4, 8, 16, 32])], + ] + + def __init__( + self, + config: str = "v1_tiny", + **kwargs, + ): + kwargs["weights_url"] = self.get_weights_url(kwargs["weights"]) + + _available_configs = ["v1_tiny", "v1_small", "v1_base"] + if config == "v1_tiny": + _config = DEFAULT_V1_TINY_CONFIG + elif config == "v1_small": + _config = DEFAULT_V1_SMALL_CONFIG + elif config == "v1_base": + _config = DEFAULT_V1_BASE_CONFIG + elif config == "v2_b0": + _config = DEFAULT_V2_B0_CONFIG + elif config == "v2_b1": + _config = DEFAULT_V2_B1_CONFIG + elif config == "v2_b2": + _config = DEFAULT_V2_B2_CONFIG + elif config == "v2_b3": + _config = DEFAULT_V2_B3_CONFIG + elif config == "v2_b4": + _config = DEFAULT_V2_B4_CONFIG + elif config == "v2_b5": + _config = DEFAULT_V2_B5_CONFIG + elif config == "v2_b6": + _config = DEFAULT_V2_B6_CONFIG + else: + raise ValueError( + f"config must be one of {_available_configs} using string. " + f"Received: config={config}" + ) + + input_tensor = kwargs.pop("input_tensor", None) + self.set_properties(kwargs) + inputs = self.determine_input_tensor( + input_tensor, + self._input_shape, + self._default_size, + ) + x = inputs + + x = self.build_preprocessing(x, "imagenet") + + # Prepare feature extraction + features = {} + + # stem + use_learnable_affine = _config["use_learnable_affine"] + stem_channels = _config["stem_channels"] + if _config["stem_type"] == "v1": + x = apply_stem_v1(x, stem_channels, name="stem") + elif _config["stem_type"] == "v2": + x = apply_stem_v2( + x, + stem_channels[0], + stem_channels[1], + use_learnable_affine=use_learnable_affine, + name="stem", + ) + else: + raise NotImplementedError + features["STEM_S4"] = x + + # stages + current_stride = 4 + stage_config = [ + _config["stage1"], + _config["stage2"], + _config["stage3"], + _config["stage4"], + ] + for current_stage_idx, (_, h, o, b, d, light, k, n) in enumerate( + stage_config + ): + x = apply_high_perf_gpu_stage( + x, + num_blocks=b, + num_layers=n, + hidden_channels=h, + output_channels=o, + kernel_size=k, + strides=2, + downsample=d, + use_light_block=light, + use_learnable_affine=use_learnable_affine, + aggregation="ese" if _config["stem_type"] == "v1" else "se", + name=f"stages_{current_stage_idx}", + ) + if d: + current_stride *= 2 + # add feature + features[f"BLOCK{current_stage_idx}_S{current_stride}"] = x + + # Head + x = self.build_head(x, use_learnable_affine=use_learnable_affine) + + super().__init__(inputs=inputs, outputs=x, features=features, **kwargs) + + # All references to `self` below this line + self.config = config + + def build_top( + self, + inputs, + classes: int, + classifier_activation: str, + dropout_rate: float, + use_learnable_affine: bool = False, + ): + class_expand = 2048 + x = layers.GlobalAveragePooling2D(name="avg_pool", keepdims=True)( + inputs + ) + x = layers.Conv2D( + class_expand, + 1, + 1, + "valid", + activation="relu", + use_bias=False, + name="head_last_conv_0", + )(x) + if use_learnable_affine: + x = kimm_layers.LearnableAffine(name="head_last_conv_2")(x) + x = layers.Dropout(rate=dropout_rate, name="head_dropout")(x) + x = layers.Flatten()(x) + x = layers.Dense( + classes, activation=classifier_activation, name="classifier" + )(x) + return x + + def build_head(self, inputs, use_learnable_affine=False): + x = inputs + if self._include_top: + x = self.build_top( + x, + self._classes, + self._classifier_activation, + self._dropout_rate, + use_learnable_affine=use_learnable_affine, + ) + else: + if self._pooling == "avg": + x = layers.GlobalAveragePooling2D(name="avg_pool")(x) + elif self._pooling == "max": + x = layers.GlobalMaxPooling2D(name="max_pool")(x) + return x + + def get_config(self): + config = super().get_config() + config.update({"config": self.config}) + return config + + def fix_config(self, config): + unused_kwargs = ["config"] + for k in unused_kwargs: + config.pop(k, None) + return config + + +""" +Model Definition +""" + + +class HGNetTiny(HGNet): + available_weights = [ + ( + "imagenet", + HGNet.default_origin, + "hgnettiny_hgnet_tiny.ssld_in1k.keras", + ) + ] + + def __init__( + self, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = "imagenet", + name: str = "HGNetTiny", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + super().__init__( + config="v1_tiny", + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, + name=name, + **kwargs, + ) + + +class HGNetSmall(HGNet): + available_weights = [ + ( + "imagenet", + HGNet.default_origin, + "hgnetsmall_hgnet_small.ssld_in1k.keras", + ) + ] + + def __init__( + self, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = "imagenet", + name: str = "HGNetSmall", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + super().__init__( + config="v1_small", + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, + name=name, + **kwargs, + ) + + +class HGNetBase(HGNet): + available_weights = [ + ( + "imagenet", + HGNet.default_origin, + "hgnetbase_hgnet_base.ssld_in1k.keras", + ) + ] + + def __init__( + self, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = "imagenet", + name: str = "HGNetBase", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + super().__init__( + config="v1_base", + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, + name=name, + **kwargs, + ) + + +class HGNetV2B0(HGNet): + available_weights = [ + ( + "imagenet", + HGNet.default_origin, + "hgnetv2b0_hgnetv2_b0.ssld_stage2_ft_in1k.keras", + ) + ] + + def __init__( + self, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = "imagenet", + name: str = "HGNetV2B0", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + super().__init__( + config="v2_b0", + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, + name=name, + **kwargs, + ) + + +class HGNetV2B1(HGNet): + available_weights = [ + ( + "imagenet", + HGNet.default_origin, + "hgnetv2b1_hgnetv2_b1.ssld_stage2_ft_in1k.keras", + ) + ] + + def __init__( + self, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = "imagenet", + name: str = "HGNetV2B1", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + super().__init__( + config="v2_b1", + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, + name=name, + **kwargs, + ) + + +class HGNetV2B2(HGNet): + available_weights = [ + ( + "imagenet", + HGNet.default_origin, + "hgnetv2b2_hgnetv2_b2.ssld_stage2_ft_in1k.keras", + ) + ] + + def __init__( + self, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = "imagenet", + name: str = "HGNetV2B2", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + super().__init__( + config="v2_b2", + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, + name=name, + **kwargs, + ) + + +class HGNetV2B3(HGNet): + available_weights = [ + ( + "imagenet", + HGNet.default_origin, + "hgnetv2b3_hgnetv2_b3.ssld_stage2_ft_in1k.keras", + ) + ] + + def __init__( + self, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = "imagenet", + name: str = "HGNetV2B3", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + super().__init__( + config="v2_b3", + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, + name=name, + **kwargs, + ) + + +class HGNetV2B4(HGNet): + available_weights = [ + ( + "imagenet", + HGNet.default_origin, + "hgnetv2b4_hgnetv2_b4.ssld_stage2_ft_in1k.keras", + ) + ] + + def __init__( + self, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = "imagenet", + name: str = "HGNetV2B4", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + super().__init__( + config="v2_b4", + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, + name=name, + **kwargs, + ) + + +class HGNetV2B5(HGNet): + available_weights = [ + ( + "imagenet", + HGNet.default_origin, + "hgnetv2b5_hgnetv2_b5.ssld_stage2_ft_in1k.keras", + ) + ] + + def __init__( + self, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = "imagenet", + name: str = "HGNetV2B5", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + super().__init__( + config="v2_b5", + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, + name=name, + **kwargs, + ) + + +class HGNetV2B6(HGNet): + available_weights = [ + ( + "imagenet", + HGNet.default_origin, + "hgnetv2b6_hgnetv2_b6.ssld_stage2_ft_in1k.keras", + ) + ] + + def __init__( + self, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = "imagenet", + name: str = "HGNetV2B6", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + super().__init__( + config="v2_b6", + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, + name=name, + **kwargs, + ) + + +add_model_to_registry(HGNetTiny, "imagenet") +add_model_to_registry(HGNetSmall, "imagenet") +add_model_to_registry(HGNetBase, "imagenet") +add_model_to_registry(HGNetV2B0, "imagenet") +add_model_to_registry(HGNetV2B1, "imagenet") +add_model_to_registry(HGNetV2B2, "imagenet") +add_model_to_registry(HGNetV2B3, "imagenet") +add_model_to_registry(HGNetV2B4, "imagenet") +add_model_to_registry(HGNetV2B5, "imagenet") +add_model_to_registry(HGNetV2B6, "imagenet") diff --git a/kimm/models/models_test.py b/kimm/models/models_test.py index 1883b80..3a676bd 100644 --- a/kimm/models/models_test.py +++ b/kimm/models/models_test.py @@ -138,6 +138,31 @@ ("BLOCK7_S32", [1, 7, 7, 160]), ], ), + # hgnet + ( + kimm_models.HGNetTiny.__name__, + kimm_models.HGNetTiny, + 224, + [ + ("STEM_S4", [1, 56, 56, 96]), + ("BLOCK0_S4", [1, 56, 56, 224]), + ("BLOCK1_S8", [1, 28, 28, 448]), + ("BLOCK2_S16", [1, 14, 14, 512]), + ("BLOCK3_S32", [1, 7, 7, 768]), + ], + ), + ( + kimm_models.HGNetV2B0.__name__, + kimm_models.HGNetV2B0, + 224, + [ + ("STEM_S4", [1, 56, 56, 16]), + ("BLOCK0_S4", [1, 56, 56, 64]), + ("BLOCK1_S8", [1, 28, 28, 256]), + ("BLOCK2_S16", [1, 14, 14, 512]), + ("BLOCK3_S32", [1, 7, 7, 1024]), + ], + ), # inception_next ( kimm_models.InceptionNeXtTiny.__name__, diff --git a/kimm/utils/timm_utils.py b/kimm/utils/timm_utils.py index c398140..1caab2e 100644 --- a/kimm/utils/timm_utils.py +++ b/kimm/utils/timm_utils.py @@ -96,6 +96,9 @@ def assign_weights( keras_weight.assign(torch_weight) elif tuple(keras_weight.shape) == tuple(torch_weight.shape): keras_weight.assign(torch_weight) + elif len(keras_weight.shape) == 0: # Deal with scalar + if len(torch_weight.shape) == 1: + keras_weight.assign(torch_weight[0]) else: raise ValueError( f"Failed to assign {keras_name}, " @@ -111,6 +114,9 @@ def is_same_weights( torch_weights: np.ndarray, ): if np.sum(keras_weights.shape) != np.sum(torch_weights.shape): + if np.sum(keras_weights.shape) == 0: # Deal with scalar + if np.sum(torch_weights.shape) == 1: + return True return False elif keras_name[-6:] == "kernel" and torch_name[-6:] != "weight": # Conv kernel diff --git a/pyproject.toml b/pyproject.toml index 6e898fd..8afe82e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,7 +72,8 @@ line-length = 80 [tool.ruff] line-length = 80 -select = ["E", "W", "F"] +lint.select = ["E", "W", "F"] +lint.isort.force-single-line = true exclude = [ ".venv", ".vscode", @@ -82,13 +83,10 @@ exclude = [ "__pycache__", ] -[tool.ruff.per-file-ignores] +[tool.ruff.lint.per-file-ignores] "./examples/**/*" = ["E402"] "**/__init__.py" = ["F401"] -[tool.ruff.isort] -force-single-line = true - [tool.isort] profile = "black" force_single_line = true diff --git a/requirements.txt b/requirements.txt index f7a3bb9..fbe34cd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,7 +10,7 @@ # "jax[cuda12_local]" # Following is for github runner -tensorflow-cpu==2.16.0rc0 # TODO: wait for TF 2.16 release +tensorflow-cpu>=2.16.1 --extra-index-url https://download.pytorch.org/whl/cpu torch>=2.1.0 diff --git a/shell/export.sh b/shell/export.sh index 3958a2e..a856728 100755 --- a/shell/export.sh +++ b/shell/export.sh @@ -9,6 +9,7 @@ python3 -m tools.convert_convnext_from_timm python3 -m tools.convert_densenet_from_timm python3 -m tools.convert_efficientnet_from_timm python3 -m tools.convert_ghostnet_from_timm +python3 -m tools.convert_hgnet_from_timm python3 -m tools.convert_inception_next_from_timm python3 -m tools.convert_inception_v3_from_timm python3 -m tools.convert_mobilenet_v2_from_timm diff --git a/shell/format.sh b/shell/format.sh index 91d6fe4..a42b7c7 100755 --- a/shell/format.sh +++ b/shell/format.sh @@ -4,4 +4,4 @@ set -Eeuo pipefail base_dir=$(dirname $(dirname $0)) isort --sp "${base_dir}/pyproject.toml" . black --config "${base_dir}/pyproject.toml" . -ruff --config "${base_dir}/pyproject.toml" . +ruff check --config "${base_dir}/pyproject.toml" . diff --git a/tools/convert_hgnet_from_timm.py b/tools/convert_hgnet_from_timm.py new file mode 100644 index 0000000..8142c78 --- /dev/null +++ b/tools/convert_hgnet_from_timm.py @@ -0,0 +1,143 @@ +""" +pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu +pip install timm +""" + +import os + +import keras +import numpy as np +import timm +import torch + +from kimm.models import hgnet +from kimm.utils.timm_utils import assign_weights +from kimm.utils.timm_utils import is_same_weights +from kimm.utils.timm_utils import separate_keras_weights +from kimm.utils.timm_utils import separate_torch_state_dict + +timm_model_names = [ + # HGNet + "hgnet_tiny.ssld_in1k", + "hgnet_small.ssld_in1k", + "hgnet_base.ssld_in1k", + # HGNetV2 + "hgnetv2_b0.ssld_stage2_ft_in1k", + "hgnetv2_b1.ssld_stage2_ft_in1k", + "hgnetv2_b2.ssld_stage2_ft_in1k", + "hgnetv2_b3.ssld_stage2_ft_in1k", + "hgnetv2_b4.ssld_stage2_ft_in1k", + "hgnetv2_b5.ssld_stage2_ft_in1k", + "hgnetv2_b6.ssld_stage2_ft_in1k", +] +keras_model_classes = [ + hgnet.HGNetTiny, + hgnet.HGNetSmall, + hgnet.HGNetBase, + hgnet.HGNetV2B0, + hgnet.HGNetV2B1, + hgnet.HGNetV2B2, + hgnet.HGNetV2B3, + hgnet.HGNetV2B4, + hgnet.HGNetV2B5, + hgnet.HGNetV2B6, +] + +for timm_model_name, keras_model_class in zip( + timm_model_names, keras_model_classes +): + """ + Prepare timm model and keras model + """ + input_shape = [224, 224, 3] + torch_model = timm.create_model(timm_model_name, pretrained=True) + torch_model = torch_model.eval() + trainable_state_dict, non_trainable_state_dict = separate_torch_state_dict( + torch_model.state_dict() + ) + keras_model = keras_model_class( + input_shape=input_shape, + include_preprocessing=False, + classifier_activation="linear", + weights=None, + ) + trainable_weights, non_trainable_weights = separate_keras_weights( + keras_model + ) + + # for torch_name, (_, keras_name) in zip( + # trainable_state_dict.keys(), trainable_weights + # ): + # print(f"{torch_name} {keras_name}") + + # print(len(trainable_state_dict.keys())) + # print(len(trainable_weights)) + + # exit() + + """ + Assign weights + """ + for keras_weight, keras_name in trainable_weights + non_trainable_weights: + keras_name: str + torch_name = keras_name + torch_name = torch_name.replace("_", ".") + # stem + if "stem.stem" not in torch_name: + # HGNet + torch_name = torch_name.replace("stem", "stem.stem") + # conv2d + torch_name = torch_name.replace("dwconv2d.kernel", "conv.weight") + torch_name = torch_name.replace("conv2d.kernel", "conv.weight") + # head + torch_name = torch_name.replace("last.conv", "last_conv") + torch_name = torch_name.replace("classifier", "head.fc") + + # weights naming mapping + torch_name = torch_name.replace("kernel", "weight") # conv2d + torch_name = torch_name.replace("gamma", "weight") # bn + torch_name = torch_name.replace("beta", "bias") # bn + torch_name = torch_name.replace("moving.mean", "running_mean") # bn + torch_name = torch_name.replace("moving.variance", "running_var") # bn + + # assign weights + if torch_name in trainable_state_dict: + torch_weights = trainable_state_dict[torch_name].numpy() + elif torch_name in non_trainable_state_dict: + torch_weights = non_trainable_state_dict[torch_name].numpy() + else: + raise ValueError( + "Can't find the corresponding torch weights. " + f"Got keras_name={keras_name}, torch_name={torch_name}" + ) + if is_same_weights(keras_name, keras_weight, torch_name, torch_weights): + assign_weights(keras_name, keras_weight, torch_weights) + else: + raise ValueError( + "Can't find the corresponding torch weights. The shape is " + f"mismatched. Got keras_name={keras_name}, " + f"keras_weight shape={keras_weight.shape}, " + f"torch_name={torch_name}, " + f"torch_weights shape={torch_weights.shape}" + ) + + """ + Verify model outputs + """ + np.random.seed(2023) + keras_data = np.random.uniform(size=[1] + input_shape).astype("float32") + torch_data = torch.from_numpy(np.transpose(keras_data, [0, 3, 1, 2])) + torch_y = torch_model(torch_data) + keras_y = keras_model(keras_data, training=False) + torch_y = torch_y.detach().cpu().numpy() + keras_y = keras.ops.convert_to_numpy(keras_y) + np.testing.assert_allclose(torch_y, keras_y, atol=1e-5) + print(f"{keras_model_class.__name__}: output matched!") + + """ + Save converted model + """ + os.makedirs("exported", exist_ok=True) + export_path = f"exported/{keras_model.name.lower()}_{timm_model_name}.keras" + keras_model.save(export_path) + print(f"Export to {export_path}")