Skip to content

Commit

Permalink
Improve test coverage and refactor BaseModel (#26)
Browse files Browse the repository at this point in the history
* Update `available_feature_keys` and `available_weights`

* Fix serialization

* Update version

* Update `README`
  • Loading branch information
james77777778 authored Jan 21, 2024
1 parent c608f1c commit 4506330
Show file tree
Hide file tree
Showing 20 changed files with 1,083 additions and 667 deletions.
16 changes: 8 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ pip install keras kimm
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/14WxYgVjlwCIO9MwqPYW-dskbTL2UHsVN?usp=sharing)

```python
import cv2
import keras
from keras import ops
from keras import utils
from keras.applications.imagenet_utils import decode_predictions

import kimm
Expand All @@ -43,15 +43,15 @@ print(kimm.list_models())
print(kimm.list_models("efficientnet", weights="imagenet")) # fuzzy search

# Initialize the model with pretrained weights
model = kimm.models.EfficientNetV2B0()
image_size = model._default_size
model = kimm.models.VisionTransformerTiny16()
image_size = (model._default_size, model._default_size)

# Load an image as the model input
image_path = keras.utils.get_file(
"african_elephant.jpg", "https://i.imgur.com/Bvro0YD.png"
)
image = cv2.imread(image_path)
image = cv2.resize(image, (image_size, image_size))
image = utils.load_img(image_path, target_size=image_size)
image = utils.img_to_array(image)
x = ops.convert_to_tensor(image)
x = ops.expand_dims(x, axis=0)

Expand All @@ -62,9 +62,9 @@ print("Predicted:", decode_predictions(preds, top=3)[0])

```bash
['ConvMixer1024D20', 'ConvMixer1536D20', 'ConvMixer736D32', 'ConvNeXtAtto', ...]
['EfficientNetB0', 'EfficientNetB1', 'EfficientNetB2', 'EfficientNetB3', ...]
1/1 ━━━━━━━━━━━━━━━━━━━━ 11s 11s/step
Predicted: [('n02504458', 'African_elephant', 0.90578836), ('n01871265', 'tusker', 0.024864597), ('n02504013', 'Indian_elephant', 0.01161992)]
['VisionTransformerBase16', 'VisionTransformerBase32', 'VisionTransformerSmall16', ...]
1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 1s/step
Predicted: [('n02504458', 'African_elephant', 0.6895825), ('n01871265', 'tusker', 0.17934209), ('n02504013', 'Indian_elephant', 0.12927249)]
```

### An end-to-end example: fine-tuning an image classification model on a cats vs. dogs dataset
Expand Down
2 changes: 1 addition & 1 deletion kimm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from kimm import models # force to add models to the registry
from kimm.utils.model_registry import list_models

__version__ = "0.1.2"
__version__ = "0.1.3"
32 changes: 22 additions & 10 deletions kimm/models/base_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import abc
import pathlib
import typing
import urllib.parse
Expand All @@ -12,6 +11,12 @@


class BaseModel(models.Model):
default_origin = (
"https://github.com/james77777778/kimm/releases/download/0.1.0/"
)
available_feature_keys = []
available_weights = []

def __init__(
self,
inputs,
Expand Down Expand Up @@ -183,12 +188,6 @@ def load_pretrained_weights(self, weights_url: typing.Optional[str] = None):
)
self.load_weights(weights_path)

@staticmethod
@abc.abstractmethod
def available_feature_keys():
# TODO: add docstring
raise NotImplementedError

def get_config(self):
# Don't chain to super here. The default `get_config()` for functional
# models is nested and cannot be passed to BaseModel.
Expand All @@ -215,6 +214,19 @@ def get_config(self):
def fix_config(self, config: typing.Dict):
return config

@property
def default_origin(self):
return "https://github.com/james77777778/kimm/releases/download/0.1.0/"
def get_weights_url(self, weights):
if weights is None:
return None

for _weights, _origin, _file_name in self.available_weights:
if weights == _weights:
return f"{_origin}/{_file_name}"

# Failed to find the weights
_available_weights_name = [
_weights for _weights, _ in self.available_weights
]
raise ValueError(
f"Available weights are {_available_weights_name}. "
f"Received weights={weights}"
)
60 changes: 29 additions & 31 deletions kimm/models/convmixer.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ def __init__(
activation: str = "relu",
**kwargs,
):
kwargs["weights_url"] = self.get_weights_url(kwargs["weights"])

input_tensor = kwargs.pop("input_tensor", None)
self.set_properties(kwargs)
inputs = self.determine_input_tensor(
Expand Down Expand Up @@ -100,10 +102,6 @@ def __init__(
self.kernel_size = kernel_size
self.activation = activation

@staticmethod
def available_feature_keys():
raise NotImplementedError

def get_config(self):
config = super().get_config()
config.update(
Expand Down Expand Up @@ -136,6 +134,15 @@ def fix_config(self, config):


class ConvMixer736D32(ConvMixer):
available_feature_keys = ["STEM", *[f"BLOCK{i}" for i in range(32)]]
available_weights = [
(
"imagenet",
ConvMixer.default_origin,
"convmixer736d32_convmixer_768_32.in1k.keras",
)
]

def __init__(
self,
input_tensor: keras.KerasTensor = None,
Expand All @@ -151,9 +158,6 @@ def __init__(
**kwargs,
):
kwargs = self.fix_config(kwargs)
if weights == "imagenet":
file_name = "convmixer736d32_convmixer_768_32.in1k.keras"
kwargs["weights_url"] = f"{self.default_origin}/{file_name}"
super().__init__(
32,
768,
Expand All @@ -173,14 +177,17 @@ def __init__(
**kwargs,
)

@staticmethod
def available_feature_keys():
feature_keys = ["STEM"]
feature_keys.extend([f"BLOCK{i}" for i in range(32)])
return feature_keys


class ConvMixer1024D20(ConvMixer):
available_feature_keys = ["STEM", *[f"BLOCK{i}" for i in range(20)]]
available_weights = [
(
"imagenet",
ConvMixer.default_origin,
"convmixer1024d20_convmixer_1024_20_ks9_p14.in1k.keras",
)
]

def __init__(
self,
input_tensor: keras.KerasTensor = None,
Expand All @@ -196,9 +203,6 @@ def __init__(
**kwargs,
):
kwargs = self.fix_config(kwargs)
if weights == "imagenet":
file_name = "convmixer1024d20_convmixer_1024_20_ks9_p14.in1k.keras"
kwargs["weights_url"] = f"{self.default_origin}/{file_name}"
super().__init__(
20,
1024,
Expand All @@ -218,14 +222,17 @@ def __init__(
**kwargs,
)

@staticmethod
def available_feature_keys():
feature_keys = ["STEM"]
feature_keys.extend([f"BLOCK{i}" for i in range(20)])
return feature_keys


class ConvMixer1536D20(ConvMixer):
available_feature_keys = ["STEM", *[f"BLOCK{i}" for i in range(20)]]
available_weights = [
(
"imagenet",
ConvMixer.default_origin,
"convmixer1536d20_convmixer_1536_20.in1k.keras",
)
]

def __init__(
self,
input_tensor: keras.KerasTensor = None,
Expand All @@ -241,9 +248,6 @@ def __init__(
**kwargs,
):
kwargs = self.fix_config(kwargs)
if weights == "imagenet":
file_name = "convmixer1536d20_convmixer_1536_20.in1k.keras"
kwargs["weights_url"] = f"{self.default_origin}/{file_name}"
super().__init__(
20,
1536,
Expand All @@ -263,12 +267,6 @@ def __init__(
**kwargs,
)

@staticmethod
def available_feature_keys():
feature_keys = ["STEM"]
feature_keys.extend([f"BLOCK{i}" for i in range(20)])
return feature_keys


add_model_to_registry(ConvMixer736D32, "imagenet")
add_model_to_registry(ConvMixer1024D20, "imagenet")
Expand Down
Loading

0 comments on commit 4506330

Please sign in to comment.