Skip to content

Commit

Permalink
Add HGNet and HGNetV2 (#41)
Browse files Browse the repository at this point in the history
* Add HGNet and HGNetV2

* Update README and requirements.txt

* Update configs

* Fix mixed precision for `LearnableAffine`
  • Loading branch information
james77777778 authored Mar 23, 2024
1 parent 8995b40 commit 195a622
Show file tree
Hide file tree
Showing 14 changed files with 1,224 additions and 8 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,8 @@ Reference: [Grad-CAM class activation visualization (keras.io)](https://keras.io
|EfficientNetV2|[ICML 2021](https://arxiv.org/abs/2104.00298)|`timm`|`kimm.models.EfficientNetV2*`|
|GhostNet|[CVPR 2020](https://arxiv.org/abs/1911.11907)|`timm`|`kimm.models.GhostNet*`|
|GhostNetV2|[NeurIPS 2022](https://arxiv.org/abs/2211.12905)|`timm`|`kimm.models.GhostNetV2*`|
|HGNet||`timm`|`kimm.models.HGNet*`|
|HGNetV2||`timm`|`kimm.models.HGNetV2*`|
|InceptionNeXt|[arXiv 2023](https://arxiv.org/abs/2303.16900)|`timm`|`kimm.models.InceptionNeXt*`|
|InceptionV3|[CVPR 2016](https://arxiv.org/abs/1512.00567)|`timm`|`kimm.models.InceptionV3`|
|LCNet|[arXiv 2021](https://arxiv.org/abs/2109.15099)|`timm`|`kimm.models.LCNet*`|
Expand Down
3 changes: 2 additions & 1 deletion kimm/blocks/base_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ def apply_conv2d_block(
if strides > 1:
padding = "valid"
x = layers.ZeroPadding2D(
(kernel_size[0] // 2, kernel_size[1] // 2), name=f"{name}_pad"
((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2),
name=f"{name}_pad",
)(x)

if not use_depthwise:
Expand Down
1 change: 1 addition & 0 deletions kimm/layers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from kimm.layers.attention import Attention
from kimm.layers.layer_scale import LayerScale
from kimm.layers.learnable_affine import LearnableAffine
from kimm.layers.mobile_one_conv2d import MobileOneConv2D
from kimm.layers.position_embedding import PositionEmbedding
from kimm.layers.rep_conv2d import RepConv2D
50 changes: 50 additions & 0 deletions kimm/layers/learnable_affine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import keras
from keras import layers
from keras import ops


@keras.saving.register_keras_serializable(package="kimm")
class LearnableAffine(layers.Layer):
def __init__(self, scale_value=1.0, bias_value=0.0, **kwargs):
super().__init__(**kwargs)
if isinstance(scale_value, int):
raise ValueError(
f"scale_value must be a integer. Received: {scale_value}"
)
if isinstance(bias_value, int):
raise ValueError(
f"bias_value must be a integer. Received: {bias_value}"
)
self.scale_value = scale_value
self.bias_value = bias_value

def build(self, input_shape):
self.scale = self.add_weight(
shape=(1,),
initializer=lambda shape, dtype: ops.cast(self.scale_value, dtype),
trainable=True,
name="scale",
)
self.bias = self.add_weight(
shape=(1,),
initializer=lambda shape, dtype: ops.cast(self.bias_value, dtype),
trainable=True,
name="bias",
)
self.built = True

def call(self, inputs, training=None, mask=None):
scale = ops.cast(self.scale, self.compute_dtype)
bias = ops.cast(self.bias, self.compute_dtype)
return ops.add(ops.multiply(inputs, scale), bias)

def get_config(self):
config = super().get_config()
config.update(
{
"scale_value": self.scale_value,
"bias_value": self.bias_value,
"name": self.name,
}
)
return config
20 changes: 20 additions & 0 deletions kimm/layers/learnable_affine_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import pytest
from absl.testing import parameterized
from keras.src import testing

from kimm.layers.learnable_affine import LearnableAffine


class LearnableAffineTest(testing.TestCase, parameterized.TestCase):
@pytest.mark.requires_trainable_backend
def test_layer_scale_basic(self):
self.run_layer_test(
LearnableAffine,
init_kwargs={"scale_value": 1.0, "bias_value": 0.0},
input_shape=(1, 10),
expected_output_shape=(1, 10),
expected_num_trainable_weights=2,
expected_num_non_trainable_weights=0,
expected_num_losses=0,
supports_masking=False,
)
1 change: 1 addition & 0 deletions kimm/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from kimm.models.densenet import * # noqa:F403
from kimm.models.efficientnet import * # noqa:F403
from kimm.models.ghostnet import * # noqa:F403
from kimm.models.hgnet import * # noqa:F403
from kimm.models.inception_next import * # noqa:F403
from kimm.models.inception_v3 import * # noqa:F403
from kimm.models.mobilenet_v2 import * # noqa:F403
Expand Down
Loading

0 comments on commit 195a622

Please sign in to comment.