Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MobileOne #36

Merged
merged 7 commits into from
Feb 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ Reference: [Grad-CAM class activation visualization (keras.io)](https://keras.io
|LCNet|[arXiv 2021](https://arxiv.org/abs/2109.15099)|`timm`|`kimm.models.LCNet*`|
|MobileNetV2|[CVPR 2018](https://arxiv.org/abs/1801.04381)|`timm`|`kimm.models.MobileNetV2*`|
|MobileNetV3|[ICCV 2019](https://arxiv.org/abs/1905.02244)|`timm`|`kimm.models.MobileNetV3*`|
|MobileOne|[CVPR 2023](https://arxiv.org/abs/2206.04040)|`timm`|`kimm.models.MobileOne*`|
|MobileViT|[ICLR 2022](https://arxiv.org/abs/2110.02178)|`timm`|`kimm.models.MobileViT*`|
|MobileViTV2|[arXiv 2022](https://arxiv.org/abs/2206.02680)|`timm`|`kimm.models.MobileViTV2*`|
|RegNet|[CVPR 2020](https://arxiv.org/abs/2003.13678)|`timm`|`kimm.models.RegNet*`|
Expand Down
2 changes: 1 addition & 1 deletion kimm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
from kimm import models # force to add models to the registry
from kimm.utils.model_registry import list_models

__version__ = "0.1.6"
__version__ = "0.1.7"
1 change: 1 addition & 0 deletions kimm/layers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from kimm.layers.attention import Attention
from kimm.layers.layer_scale import LayerScale
from kimm.layers.mobile_one_conv2d import MobileOneConv2D
from kimm.layers.position_embedding import PositionEmbedding
from kimm.layers.rep_conv2d import RepConv2D
14 changes: 6 additions & 8 deletions kimm/layers/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ def __init__(
use_qk_norm: bool = False,
attention_dropout_rate: float = 0.0,
projection_dropout_rate: float = 0.0,
name: str = "attention",
**kwargs,
):
super().__init__(**kwargs)
Expand All @@ -25,20 +24,19 @@ def __init__(
self.use_qk_norm = use_qk_norm
self.attention_dropout_rate = attention_dropout_rate
self.projection_dropout_rate = projection_dropout_rate
self.name = name

self.qkv = layers.Dense(
hidden_dim * 3,
use_bias=use_qkv_bias,
dtype=self.dtype_policy,
name=f"{name}_qkv",
name=f"{self.name}_qkv",
)
if use_qk_norm:
self.q_norm = layers.LayerNormalization(
dtype=self.dtype_policy, name=f"{name}_q_norm"
dtype=self.dtype_policy, name=f"{self.name}_q_norm"
)
self.k_norm = layers.LayerNormalization(
dtype=self.dtype_policy, name=f"{name}_k_norm"
dtype=self.dtype_policy, name=f"{self.name}_k_norm"
)
else:
self.q_norm = layers.Identity(dtype=self.dtype_policy)
Expand All @@ -47,15 +45,15 @@ def __init__(
self.attention_dropout = layers.Dropout(
attention_dropout_rate,
dtype=self.dtype_policy,
name=f"{name}_attn_drop",
name=f"{self.name}_attn_drop",
)
self.projection = layers.Dense(
hidden_dim, dtype=self.dtype_policy, name=f"{name}_proj"
hidden_dim, dtype=self.dtype_policy, name=f"{self.name}_proj"
)
self.projection_dropout = layers.Dropout(
projection_dropout_rate,
dtype=self.dtype_policy,
name=f"{name}_proj_drop",
name=f"{self.name}_proj_drop",
)

def build(self, input_shape):
Expand Down
2 changes: 0 additions & 2 deletions kimm/layers/layer_scale.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,11 @@ def __init__(
self,
axis: int = -1,
initializer: Initializer = initializers.Constant(1e-5),
name: str = "layer_scale",
**kwargs,
):
super().__init__(**kwargs)
self.axis = axis
self.initializer = initializer
self.name = name

def build(self, input_shape):
if isinstance(self.axis, list):
Expand Down
Loading