Add ConvNeXt and refactor BaseModel (#16)
* Add `ConvNeXt`

* Update `requirements.txt`

* Refactor `BaseModel` to reduce redundant code
james77777778 authored Jan 19, 2024
1 parent face1a0 commit 03d84b4
Showing 26 changed files with 1,046 additions and 450 deletions.
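For orientation before the per-file diffs: the refactor moves the shared constructor plumbing into `BaseModel`, so each architecture now calls `set_properties()` first and then reuses `determine_input_tensor`, `build_preprocessing` and the new `build_head` helper. Below is a minimal sketch (not part of the commit) of that pattern; `TinyNet` and its layer names are hypothetical, and the abstract staticmethod declared at the end of base_model.py is omitted for brevity.

from keras import layers

from kimm.models.base_model import BaseModel


class TinyNet(BaseModel):
    def __init__(self, **kwargs):
        input_tensor = kwargs.pop("input_tensor", None)
        # Pops the common kwargs (input_shape, include_top, pooling, classes,
        # weights_url, feature_extractor, ...) onto `self._*` attributes.
        self.set_properties(kwargs)
        inputs = self.determine_input_tensor(
            input_tensor, self._input_shape, self._default_size
        )
        x = self.build_preprocessing(inputs, "imagenet")

        # Record intermediate tensors for the feature-extractor mode.
        features = {}
        x = layers.Conv2D(8, 3, padding="same", name="conv1")(x)
        features["BLOCK0"] = x
        x = layers.Conv2D(16, 3, padding="same", name="conv2")(x)
        features["BLOCK1"] = x

        # Replaces the per-model include_top / pooling boilerplate.
        x = self.build_head(x)
        super().__init__(inputs=inputs, outputs=x, features=features, **kwargs)


# Plain classifier: a single output tensor.
model = TinyNet(include_top=True, classes=10)
# Feature extractor: outputs become a dict of the requested blocks plus "TOP".
extractor = TinyNet(feature_extractor=True, feature_keys=["BLOCK1"])

The ConvMixer hunks further down show the same before/after on a real model.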
15 changes: 13 additions & 2 deletions kimm/blocks/transformer_block.py
@@ -11,18 +11,29 @@ def apply_mlp_block(
     normalization=None,
     use_bias=True,
     dropout_rate=0.0,
+    use_conv_mlp=False,
     name="mlp_block",
 ):
     input_dim = inputs.shape[-1]
     output_dim = output_dim or input_dim

     x = inputs
-    x = layers.Dense(hidden_dim, use_bias=use_bias, name=f"{name}_fc1")(x)
+    if use_conv_mlp:
+        x = layers.Conv2D(
+            hidden_dim, 1, use_bias=use_bias, name=f"{name}_fc1_conv2d"
+        )(x)
+    else:
+        x = layers.Dense(hidden_dim, use_bias=use_bias, name=f"{name}_fc1")(x)
     x = layers.Activation(activation, name=f"{name}_act")(x)
     x = layers.Dropout(dropout_rate, name=f"{name}_drop1")(x)
     if normalization is not None:
         x = normalization(name=f"{name}_norm")(x)
-    x = layers.Dense(output_dim, use_bias=use_bias, name=f"{name}_fc2")(x)
+    if use_conv_mlp:
+        x = layers.Conv2D(
+            output_dim, 1, use_bias=use_bias, name=f"{name}_fc2_conv2d"
+        )(x)
+    else:
+        x = layers.Dense(output_dim, use_bias=use_bias, name=f"{name}_fc2")(x)
     x = layers.Dropout(dropout_rate, name=f"{name}_drop2")(x)
     return x

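The new `use_conv_mlp` flag swaps the two `Dense` projections for 1x1 `Conv2D` layers, keeping the MLP block expressed as convolutions on a 4D feature map (the layout the newly added ConvNeXt works with). A hedged usage sketch; the leading parameters of `apply_mlp_block` (`inputs`, `hidden_dim`, `activation`, ...) are inferred from the body above rather than copied from the full signature, and the import path simply mirrors the file location.

import keras
from keras import layers

from kimm.blocks.transformer_block import apply_mlp_block

inputs = layers.Input(shape=(14, 14, 96))  # channels-last feature map
outputs = apply_mlp_block(
    inputs,
    hidden_dim=384,
    activation="gelu",
    use_conv_mlp=True,  # 1x1 Conv2D for fc1/fc2 instead of Dense
    name="block0_mlp",
)
model = keras.Model(inputs, outputs)  # output shape: (None, 14, 14, 96)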
2 changes: 2 additions & 0 deletions kimm/layers/attention.py
@@ -1,7 +1,9 @@
+import keras
 from keras import layers
 from keras import ops


+@keras.saving.register_keras_serializable(package="kimm")
 class Attention(layers.Layer):
     def __init__(
         self,
2 changes: 2 additions & 0 deletions kimm/layers/layer_scale.py
@@ -1,8 +1,10 @@
+import keras
 from keras import initializers
 from keras import layers
 from keras import ops


+@keras.saving.register_keras_serializable(package="kimm")
 class LayerScale(layers.Layer):
     def __init__(
         self,
2 changes: 2 additions & 0 deletions kimm/layers/position_embedding.py
@@ -1,7 +1,9 @@
+import keras
 from keras import layers
 from keras import ops


+@keras.saving.register_keras_serializable(package="kimm")
 class PositionEmbedding(layers.Layer):
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
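All three layer modules gain `import keras` plus the `@keras.saving.register_keras_serializable(package="kimm")` decorator, which registers the classes under a `kimm>` prefix so saved models containing them can be reloaded without passing `custom_objects`. A small sketch to confirm the registration (importing by the module path shown above; `kimm.layers` may also re-export these classes):

import keras

# Importing the module runs the decorator and registers the class.
from kimm.layers.layer_scale import LayerScale

print(keras.saving.get_registered_name(LayerScale))
# expected: kimm>LayerScale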
1 change: 1 addition & 0 deletions kimm/models/__init__.py
@@ -1,5 +1,6 @@
 from kimm.models.base_model import BaseModel
 from kimm.models.convmixer import *  # noqa:F403
+from kimm.models.convnext import *  # noqa:F403
 from kimm.models.densenet import *  # noqa:F403
 from kimm.models.efficientnet import *  # noqa:F403
 from kimm.models.ghostnet import *  # noqa:F403
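The added wildcard import exposes whatever ConvNeXt variants `kimm/models/convnext.py` defines (that file is not shown in this excerpt) under `kimm.models`. A quick way to list them without assuming the exact class names:

import kimm.models

convnext_names = [
    name for name in dir(kimm.models) if name.lower().startswith("convnext")
]
print(convnext_names)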
171 changes: 111 additions & 60 deletions kimm/models/base_model.py
@@ -1,10 +1,13 @@
 import abc
+import pathlib
 import typing
+import urllib.parse

 from keras import KerasTensor
 from keras import backend
 from keras import layers
 from keras import models
+from keras import utils
 from keras.src.applications import imagenet_utils


@@ -14,53 +17,79 @@ def __init__(
         inputs,
         outputs,
         features: typing.Optional[typing.Dict[str, KerasTensor]] = None,
-        feature_keys: typing.Optional[typing.List[str]] = None,
         **kwargs,
     ):
-        self.feature_extractor = kwargs.pop("feature_extractor", False)
-        self.feature_keys = feature_keys
-        if self.feature_extractor:
-            if features is None:
-                raise ValueError(
-                    "`features` must be set when "
-                    f"`feature_extractor=True`. Received features={features}"
-                )
-            if self.feature_keys is None:
-                self.feature_keys = list(features.keys())
-            filtered_features = {}
-            for k in self.feature_keys:
-                if k not in features:
-                    raise KeyError(
-                        f"'{k}' is not a key of `features`. Available keys "
-                        f"are: {list(features.keys())}"
-                    )
-                filtered_features[k] = features[k]
-            # add outputs
-            if backend.is_keras_tensor(outputs):
-                filtered_features["TOP"] = outputs
-            super().__init__(inputs=inputs, outputs=filtered_features, **kwargs)
-        else:
+        if not hasattr(self, "_feature_extractor"):
             del features
             super().__init__(inputs=inputs, outputs=outputs, **kwargs)
+        else:
+            if not hasattr(self, "_feature_keys"):
+                raise AttributeError(
+                    "`self._feature_keys` must be set when initializing "
+                    "BaseModel"
+                )
+            if self._feature_extractor:
+                if features is None:
+                    raise ValueError(
+                        "`features` must be set when `feature_extractor=True`. "
+                        f"Received features={features}"
+                    )
+                if self._feature_keys is None:
+                    self._feature_keys = list(features.keys())
+                filtered_features = {}
+                for k in self._feature_keys:
+                    if k not in features:
+                        raise KeyError(
+                            f"'{k}' is not a key of `features`. Available keys "
+                            f"are: {list(features.keys())}"
+                        )
+                    filtered_features[k] = features[k]
+                # Add outputs
+                if backend.is_keras_tensor(outputs):
+                    filtered_features["TOP"] = outputs
+                super().__init__(
+                    inputs=inputs, outputs=filtered_features, **kwargs
+                )
+            else:
+                del features
+                super().__init__(inputs=inputs, outputs=outputs, **kwargs)
+
+        if hasattr(self, "_weights_url"):
+            self.load_pretrained_weights(self._weights_url)

-    def parse_kwargs(
+    def set_properties(
         self, kwargs: typing.Dict[str, typing.Any], default_size: int = 224
     ):
-        result = {
-            "input_tensor": kwargs.pop("input_tensor", None),
-            "input_shape": kwargs.pop("input_shape", None),
-            "include_preprocessing": kwargs.pop("include_preprocessing", True),
-            "include_top": kwargs.pop("include_top", True),
-            "pooling": kwargs.pop("pooling", None),
-            "dropout_rate": kwargs.pop("dropout_rate", 0.0),
-            "classes": kwargs.pop("classes", 1000),
-            "classifier_activation": kwargs.pop(
-                "classifier_activation", "softmax"
-            ),
-            "weights": kwargs.pop("weights", "imagenet"),
-            "default_size": kwargs.pop("default_size", default_size),
-        }
-        return result
+        """Must be called in the initilization of the class.
+
+        This method will add following common properties to the model object:
+        - input_shape
+        - include_preprocessing
+        - include_top
+        - pooling
+        - dropout_rate
+        - classes
+        - classifier_activation
+        - _weights
+        - weights_url
+        - default_size
+        """
+        self._input_shape = kwargs.pop("input_shape", None)
+        self._include_preprocessing = kwargs.pop("include_preprocessing", True)
+        self._include_top = kwargs.pop("include_top", True)
+        self._pooling = kwargs.pop("pooling", None)
+        self._dropout_rate = kwargs.pop("dropout_rate", 0.0)
+        self._classes = kwargs.pop("classes", 1000)
+        self._classifier_activation = kwargs.pop(
+            "classifier_activation", "softmax"
+        )
+        self._weights = kwargs.pop("weights", None)
+        self._weights_url = kwargs.pop("weights_url", None)
+        self._default_size = kwargs.pop("default_size", default_size)
+        # feature extractor
+        self._feature_extractor = kwargs.pop("feature_extractor", False)
+        self._feature_keys = kwargs.pop("feature_keys", None)
+        print("self._feature_keys", self._feature_keys)

     def determine_input_tensor(
         self,
@@ -87,10 +116,12 @@ def determine_input_tensor(
         if not backend.is_keras_tensor(input_tensor):
             x = layers.Input(tensor=input_tensor, shape=input_shape)
         else:
-            x = input_tensor
+            x = utils.get_source_inputs(input_tensor)
         return x

     def build_preprocessing(self, inputs, mode="imagenet"):
+        if self._include_preprocessing is False:
+            return inputs
         if mode == "imagenet":
             # [0, 255] to [0, 1] and apply ImageNet mean and variance
             x = layers.Rescaling(scale=1.0 / 255.0)(inputs)
@@ -118,15 +149,30 @@ def build_top(self, inputs, classes, classifier_activation, dropout_rate):
         )(x)
         return x

-    def add_references(self, parsed_kwargs: typing.Dict[str, typing.Any]):
-        self.include_preprocessing = parsed_kwargs["include_preprocessing"]
-        self.include_top = parsed_kwargs["include_top"]
-        self.pooling = parsed_kwargs["pooling"]
-        self.dropout_rate = parsed_kwargs["dropout_rate"]
-        self.classes = parsed_kwargs["classes"]
-        self.classifier_activation = parsed_kwargs["classifier_activation"]
-        # `self.weights` is been used internally
-        self._weights = parsed_kwargs["weights"]
+    def build_head(self, inputs):
+        x = inputs
+        if self._include_top:
+            x = self.build_top(
+                x,
+                self._classes,
+                self._classifier_activation,
+                self._dropout_rate,
+            )
+        else:
+            if self._pooling == "avg":
+                x = layers.GlobalAveragePooling2D(name="avg_pool")(x)
+            elif self._pooling == "max":
+                x = layers.GlobalMaxPooling2D(name="max_pool")(x)
+        return x

+    def load_pretrained_weights(self, weights_url: typing.Optional[str] = None):
+        if weights_url is not None:
+            result = urllib.parse.urlparse(weights_url)
+            file_name = pathlib.Path(result.path).name
+            weights_path = utils.get_file(
+                file_name, weights_url, cache_subdir="kimm_models"
+            )
+            self.load_weights(weights_path)
+
     @staticmethod
     @abc.abstractmethod
@@ -141,20 +187,25 @@ def get_config(self):
             # models.Model
             "name": self.name,
             "trainable": self.trainable,
-            # feature extractor
-            "feature_extractor": self.feature_extractor,
-            "feature_keys": self.feature_keys,
-            # common
             "input_shape": self.input_shape[1:],
-            "include_preprocessing": self.include_preprocessing,
-            "include_top": self.include_top,
-            "pooling": self.pooling,
-            "dropout_rate": self.dropout_rate,
-            "classes": self.classes,
-            "classifier_activation": self.classifier_activation,
+            # common
+            "include_preprocessing": self._include_preprocessing,
+            "include_top": self._include_top,
+            "pooling": self._pooling,
+            "dropout_rate": self._dropout_rate,
+            "classes": self._classes,
+            "classifier_activation": self._classifier_activation,
+            "weights": self._weights,
+            "weights_url": self._weights_url,
+            # feature extractor
+            "feature_extractor": self._feature_extractor,
+            "feature_keys": self._feature_keys,
         }
         return config

+    def fix_config(self, config: typing.Dict):
+        return config
+
     @property
     def default_origin(self):
         return "https://github.com/james77777778/keras-aug/releases/download/v0.5.0"
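Two consequences of this hunk worth noting: `get_config()` now serializes every common constructor argument (including the new `weights_url` and the feature-extractor settings), and pretrained weights are fetched from `weights_url` via `keras.utils.get_file` into the `kimm_models` cache directory. A self-contained sketch with a hypothetical minimal subclass (`Probe` is not part of kimm):

from keras import layers

from kimm.models.base_model import BaseModel


class Probe(BaseModel):
    def __init__(self, **kwargs):
        self.set_properties(kwargs)
        inputs = layers.Input(shape=(8, 8, 3))
        x = layers.Conv2D(4, 3, padding="same", name="conv")(inputs)
        # include_top=False + pooling="avg" ends with GlobalAveragePooling2D.
        outputs = self.build_head(x)
        super().__init__(inputs=inputs, outputs=outputs, **kwargs)


config = Probe(include_top=False, pooling="avg").get_config()
print(config["pooling"], config["weights_url"], config["feature_extractor"])
# expected: avg None False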
1 change: 1 addition & 0 deletions kimm/models/base_model_test.py
@@ -8,6 +8,7 @@

 class SampleModel(BaseModel):
     def __init__(self, **kwargs):
+        self.set_properties(kwargs)
         inputs = layers.Input(shape=[224, 224, 3])

         features = {}
39 changes: 10 additions & 29 deletions kimm/models/convmixer.py
@@ -2,7 +2,6 @@

 import keras
 from keras import layers
-from keras import utils

 from kimm.models.base_model import BaseModel
 from kimm.utils import add_model_to_registry
@@ -42,6 +41,7 @@ def apply_convmixer_block(
     return x


+@keras.saving.register_keras_serializable(package="kimm")
 class ConvMixer(BaseModel):
     def __init__(
         self,
@@ -52,16 +52,16 @@ def __init__(
         activation: str = "relu",
         **kwargs,
     ):
-        parsed_kwargs = self.parse_kwargs(kwargs)
-        img_input = self.determine_input_tensor(
-            parsed_kwargs["input_tensor"],
-            parsed_kwargs["input_shape"],
-            parsed_kwargs["default_size"],
+        input_tensor = kwargs.pop("input_tensor", None)
+        self.set_properties(kwargs)
+        inputs = self.determine_input_tensor(
+            input_tensor,
+            self._input_shape,
+            self._default_size,
         )
-        x = img_input
+        x = inputs

-        if parsed_kwargs["include_preprocessing"]:
-            x = self.build_preprocessing(x, "imagenet")
+        x = self.build_preprocessing(x, "imagenet")

         # Prepare feature extraction
         features = {}
@@ -89,30 +89,11 @@ def __init__(
             features[f"BLOCK{i}"] = x

         # Head
-        if parsed_kwargs["include_top"]:
-            x = self.build_top(
-                x,
-                parsed_kwargs["classes"],
-                parsed_kwargs["classifier_activation"],
-                parsed_kwargs["dropout_rate"],
-            )
-        else:
-            if parsed_kwargs["pooling"] == "avg":
-                x = layers.GlobalAveragePooling2D(name="avg_pool")(x)
-            elif parsed_kwargs["pooling"] == "max":
-                x = layers.GlobalMaxPooling2D(name="max_pool")(x)

-        # Ensure that the model takes into account
-        # any potential predecessors of `input_tensor`.
-        if parsed_kwargs["input_tensor"] is not None:
-            inputs = utils.get_source_inputs(parsed_kwargs["input_tensor"])
-        else:
-            inputs = img_input
+        x = self.build_head(x)

         super().__init__(inputs=inputs, outputs=x, features=features, **kwargs)

-        # All references to `self` below this line
-        self.add_references(parsed_kwargs)
         self.depth = depth
         self.hidden_channels = hidden_channels
         self.patch_size = patch_size
