Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve test coverage and refactor BaseModel #26

Merged
merged 4 commits into from
Jan 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ pip install keras kimm
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/14WxYgVjlwCIO9MwqPYW-dskbTL2UHsVN?usp=sharing)

```python
import cv2
import keras
from keras import ops
from keras import utils
from keras.applications.imagenet_utils import decode_predictions

import kimm
Expand All @@ -43,15 +43,15 @@ print(kimm.list_models())
print(kimm.list_models("efficientnet", weights="imagenet")) # fuzzy search

# Initialize the model with pretrained weights
model = kimm.models.EfficientNetV2B0()
image_size = model._default_size
model = kimm.models.VisionTransformerTiny16()
image_size = (model._default_size, model._default_size)

# Load an image as the model input
image_path = keras.utils.get_file(
"african_elephant.jpg", "https://i.imgur.com/Bvro0YD.png"
)
image = cv2.imread(image_path)
image = cv2.resize(image, (image_size, image_size))
image = utils.load_img(image_path, target_size=image_size)
image = utils.img_to_array(image)
x = ops.convert_to_tensor(image)
x = ops.expand_dims(x, axis=0)

Expand All @@ -62,9 +62,9 @@ print("Predicted:", decode_predictions(preds, top=3)[0])

```bash
['ConvMixer1024D20', 'ConvMixer1536D20', 'ConvMixer736D32', 'ConvNeXtAtto', ...]
['EfficientNetB0', 'EfficientNetB1', 'EfficientNetB2', 'EfficientNetB3', ...]
1/1 ━━━━━━━━━━━━━━━━━━━━ 11s 11s/step
Predicted: [('n02504458', 'African_elephant', 0.90578836), ('n01871265', 'tusker', 0.024864597), ('n02504013', 'Indian_elephant', 0.01161992)]
['VisionTransformerBase16', 'VisionTransformerBase32', 'VisionTransformerSmall16', ...]
1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 1s/step
Predicted: [('n02504458', 'African_elephant', 0.6895825), ('n01871265', 'tusker', 0.17934209), ('n02504013', 'Indian_elephant', 0.12927249)]
```

### An end-to-end example: fine-tuning an image classification model on a cats vs. dogs dataset
Expand Down
2 changes: 1 addition & 1 deletion kimm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from kimm import models # force to add models to the registry
from kimm.utils.model_registry import list_models

__version__ = "0.1.2"
__version__ = "0.1.3"
32 changes: 22 additions & 10 deletions kimm/models/base_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import abc
import pathlib
import typing
import urllib.parse
Expand All @@ -12,6 +11,12 @@


class BaseModel(models.Model):
default_origin = (
"https://github.com/james77777778/kimm/releases/download/0.1.0/"
)
available_feature_keys = []
available_weights = []

def __init__(
self,
inputs,
Expand Down Expand Up @@ -183,12 +188,6 @@ def load_pretrained_weights(self, weights_url: typing.Optional[str] = None):
)
self.load_weights(weights_path)

@staticmethod
@abc.abstractmethod
# Abstract hook (removed in this PR in favor of a class attribute):
# subclasses were required to return the list of feature-map keys
# (e.g. "STEM", "BLOCK0", ...) that the model can expose.
def available_feature_keys():
# Intentionally abstract — concrete models supply the real key list.
raise NotImplementedError

def get_config(self):
# Don't chain to super here. The default `get_config()` for functional
# models is nested and cannot be passed to BaseModel.
Expand All @@ -215,6 +214,19 @@ def get_config(self):
def fix_config(self, config: typing.Dict):
return config

@property
def default_origin(self):
return "https://github.com/james77777778/kimm/releases/download/0.1.0/"
def get_weights_url(self, weights):
    """Resolve a pretrained-weights alias to its download URL.

    Args:
        weights: Name of the pretrained weights (e.g. ``"imagenet"``),
            or ``None`` for no pretrained weights.

    Returns:
        The full URL string built from the matching
        ``(alias, origin, file_name)`` entry of ``self.available_weights``,
        or ``None`` when ``weights`` is ``None``.

    Raises:
        ValueError: If ``weights`` matches no entry in
            ``self.available_weights``.
    """
    if weights is None:
        return None

    # Each entry is a 3-tuple: (alias, origin URL, file name).
    for _weights, _origin, _file_name in self.available_weights:
        if weights == _weights:
            return f"{_origin}/{_file_name}"

    # Failed to find the weights.
    # Bug fix: entries are 3-tuples, so the previous 2-name unpacking
    # (`for _weights, _ in ...`) raised "too many values to unpack"
    # here instead of producing the intended error message.
    _available_weights_name = [
        _weights for _weights, _, _ in self.available_weights
    ]
    raise ValueError(
        f"Available weights are {_available_weights_name}. "
        f"Received weights={weights}"
    )
60 changes: 29 additions & 31 deletions kimm/models/convmixer.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ def __init__(
activation: str = "relu",
**kwargs,
):
kwargs["weights_url"] = self.get_weights_url(kwargs["weights"])

input_tensor = kwargs.pop("input_tensor", None)
self.set_properties(kwargs)
inputs = self.determine_input_tensor(
Expand Down Expand Up @@ -100,10 +102,6 @@ def __init__(
self.kernel_size = kernel_size
self.activation = activation

@staticmethod
def available_feature_keys():
# The ConvMixer base class exposes no feature keys itself; each concrete
# variant (ConvMixer736D32, ConvMixer1024D20, ...) overrides this.
raise NotImplementedError

def get_config(self):
config = super().get_config()
config.update(
Expand Down Expand Up @@ -136,6 +134,15 @@ def fix_config(self, config):


class ConvMixer736D32(ConvMixer):
available_feature_keys = ["STEM", *[f"BLOCK{i}" for i in range(32)]]
available_weights = [
(
"imagenet",
ConvMixer.default_origin,
"convmixer736d32_convmixer_768_32.in1k.keras",
)
]

def __init__(
self,
input_tensor: keras.KerasTensor = None,
Expand All @@ -151,9 +158,6 @@ def __init__(
**kwargs,
):
kwargs = self.fix_config(kwargs)
if weights == "imagenet":
file_name = "convmixer736d32_convmixer_768_32.in1k.keras"
kwargs["weights_url"] = f"{self.default_origin}/{file_name}"
super().__init__(
32,
768,
Expand All @@ -173,14 +177,17 @@ def __init__(
**kwargs,
)

@staticmethod
def available_feature_keys():
feature_keys = ["STEM"]
feature_keys.extend([f"BLOCK{i}" for i in range(32)])
return feature_keys


class ConvMixer1024D20(ConvMixer):
available_feature_keys = ["STEM", *[f"BLOCK{i}" for i in range(20)]]
available_weights = [
(
"imagenet",
ConvMixer.default_origin,
"convmixer1024d20_convmixer_1024_20_ks9_p14.in1k.keras",
)
]

def __init__(
self,
input_tensor: keras.KerasTensor = None,
Expand All @@ -196,9 +203,6 @@ def __init__(
**kwargs,
):
kwargs = self.fix_config(kwargs)
if weights == "imagenet":
file_name = "convmixer1024d20_convmixer_1024_20_ks9_p14.in1k.keras"
kwargs["weights_url"] = f"{self.default_origin}/{file_name}"
super().__init__(
20,
1024,
Expand All @@ -218,14 +222,17 @@ def __init__(
**kwargs,
)

@staticmethod
def available_feature_keys():
feature_keys = ["STEM"]
feature_keys.extend([f"BLOCK{i}" for i in range(20)])
return feature_keys


class ConvMixer1536D20(ConvMixer):
available_feature_keys = ["STEM", *[f"BLOCK{i}" for i in range(20)]]
available_weights = [
(
"imagenet",
ConvMixer.default_origin,
"convmixer1536d20_convmixer_1536_20.in1k.keras",
)
]

def __init__(
self,
input_tensor: keras.KerasTensor = None,
Expand All @@ -241,9 +248,6 @@ def __init__(
**kwargs,
):
kwargs = self.fix_config(kwargs)
if weights == "imagenet":
file_name = "convmixer1536d20_convmixer_1536_20.in1k.keras"
kwargs["weights_url"] = f"{self.default_origin}/{file_name}"
super().__init__(
20,
1536,
Expand All @@ -263,12 +267,6 @@ def __init__(
**kwargs,
)

@staticmethod
def available_feature_keys():
feature_keys = ["STEM"]
feature_keys.extend([f"BLOCK{i}" for i in range(20)])
return feature_keys


add_model_to_registry(ConvMixer736D32, "imagenet")
add_model_to_registry(ConvMixer1024D20, "imagenet")
Expand Down
Loading