Skip to content

Commit

Permalink
Fix import bug and refactor model registry (#12)
Browse files Browse the repository at this point in the history
* Fix import error

* Rename `list_models` and add predicitons to features
  • Loading branch information
james77777778 authored Jan 17, 2024
1 parent d7804ac commit 6e37c90
Show file tree
Hide file tree
Showing 24 changed files with 187 additions and 164 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import kimm
print(kimm.list_models())

# Specify the name and other arguments to filter the result
print(kimm.list_models("efficientnet", has_pretrained=True)) # fuzzy search
print(kimm.list_models("efficientnet", weights="imagenet")) # fuzzy search

# Initialize the model with pretrained weights
model = kimm.models.EfficientNetV2B0(weights="imagenet")
Expand All @@ -36,7 +36,7 @@ print(y.shape)

# Initialize the model as a feature extractor with pretrained weights
model = kimm.models.EfficientNetV2B0(
as_feature_extractor=True, weights="imagenet"
feature_extractor=True, weights="imagenet"
)

# Extract features for downstream tasks
Expand Down
3 changes: 3 additions & 0 deletions kimm/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from kimm.models.base_model import BaseModel
from kimm.models.densenet import * # noqa:F403
from kimm.models.efficientnet import * # noqa:F403
from kimm.models.ghostnet import * # noqa:F403
from kimm.models.inception_v3 import * # noqa:F403
from kimm.models.mobilenet_v2 import * # noqa:F403
from kimm.models.mobilenet_v3 import * # noqa:F403
from kimm.models.mobilevit import * # noqa:F403
from kimm.models.resnet import * # noqa:F403
from kimm.models.vision_transformer import * # noqa:F403
11 changes: 7 additions & 4 deletions kimm/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@ def __init__(
feature_keys: typing.Optional[typing.List[str]] = None,
**kwargs,
):
self.as_feature_extractor = kwargs.pop("as_feature_extractor", False)
self.feature_extractor = kwargs.pop("feature_extractor", False)
self.feature_keys = feature_keys
if self.as_feature_extractor:
if self.feature_extractor:
if features is None:
raise ValueError(
"`features` must be set when "
f"`as_feature_extractor=True`. Got features={features}"
f"`feature_extractor=True`. Received features={features}"
)
if self.feature_keys is None:
self.feature_keys = list(features.keys())
Expand All @@ -35,6 +35,9 @@ def __init__(
f"are: {list(features.keys())}"
)
filtered_features[k] = features[k]
# add outputs
if backend.is_keras_tensor(outputs):
filtered_features["TOP"] = outputs
super().__init__(inputs=inputs, outputs=filtered_features, **kwargs)
else:
del features
Expand Down Expand Up @@ -139,7 +142,7 @@ def get_config(self):
"name": self.name,
"trainable": self.trainable,
# feature extractor
"as_feature_extractor": self.as_feature_extractor,
"feature_extractor": self.feature_extractor,
"feature_keys": self.feature_keys,
# common
"input_shape": self.input_shape[1:],
Expand Down
14 changes: 7 additions & 7 deletions kimm/models/base_model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,31 +39,31 @@ def test_feature_extractor(self):
x = random.uniform([1, 224, 224, 3])

# availiable_feature_keys
self.assertEqual(
SampleModel.available_feature_keys(),
self.assertContainsSubset(
["S2", "S4", "S8", "S16", "S32"],
SampleModel.available_feature_keys(),
)

# as_feature_extractor=False
# feature_extractor=False
model = SampleModel()

y = model(x, training=False)

self.assertNotIsInstance(y, dict)
self.assertEqual(list(y.shape), [1, 7, 7, 3])

# as_feature_extractor=True
model = SampleModel(as_feature_extractor=True)
# feature_extractor=True
model = SampleModel(feature_extractor=True)

y = model(x, training=False)

self.assertIsInstance(y, dict)
self.assertEqual(list(y["S2"].shape), [1, 112, 112, 3])
self.assertEqual(list(y["S32"].shape), [1, 7, 7, 3])

# as_feature_extractor=True with feature_keys
# feature_extractor=True with feature_keys
model = SampleModel(
as_feature_extractor=True, feature_keys=["S2", "S16", "S32"]
feature_extractor=True, feature_keys=["S2", "S16", "S32"]
)

y = model(x, training=False)
Expand Down
8 changes: 4 additions & 4 deletions kimm/models/densenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ def __init__(
)


add_model_to_registry(DenseNet121, True)
add_model_to_registry(DenseNet161, True)
add_model_to_registry(DenseNet169, True)
add_model_to_registry(DenseNet201, True)
add_model_to_registry(DenseNet121, "imagenet")
add_model_to_registry(DenseNet161, "imagenet")
add_model_to_registry(DenseNet169, "imagenet")
add_model_to_registry(DenseNet201, "imagenet")
9 changes: 4 additions & 5 deletions kimm/models/densenet_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,14 @@ def test_densenet_base(self, model_class):
@parameterized.named_parameters([(DenseNet121.__name__, DenseNet121)])
def test_densenet_feature_extractor(self, model_class):
x = random.uniform([1, 224, 224, 3]) * 255.0
model = model_class(
input_shape=[224, 224, 3], as_feature_extractor=True
)
model = model_class(input_shape=[224, 224, 3], feature_extractor=True)

y = model(x, training=False)

self.assertIsInstance(y, dict)
self.assertAllEqual(
list(y.keys()), model_class.available_feature_keys()
self.assertContainsSubset(
model_class.available_feature_keys(),
list(y.keys()),
)
self.assertEqual(list(y["STEM_S4"].shape), [1, 56, 56, 64])
self.assertEqual(list(y["BLOCK0_S8"].shape), [1, 28, 28, 128])
Expand Down
52 changes: 26 additions & 26 deletions kimm/models/efficientnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -1514,29 +1514,29 @@ def __init__(
)


add_model_to_registry(EfficientNetB0, True)
add_model_to_registry(EfficientNetB1, True)
add_model_to_registry(EfficientNetB2, True)
add_model_to_registry(EfficientNetB3, True)
add_model_to_registry(EfficientNetB4, True)
add_model_to_registry(EfficientNetB5, True)
add_model_to_registry(EfficientNetB6, True)
add_model_to_registry(EfficientNetB7, True)
add_model_to_registry(EfficientNetLiteB0, True)
add_model_to_registry(EfficientNetLiteB1, True)
add_model_to_registry(EfficientNetLiteB2, True)
add_model_to_registry(EfficientNetLiteB3, True)
add_model_to_registry(EfficientNetLiteB4, True)
add_model_to_registry(EfficientNetV2S, True)
add_model_to_registry(EfficientNetV2M, True)
add_model_to_registry(EfficientNetV2L, True)
add_model_to_registry(EfficientNetV2XL, True)
add_model_to_registry(EfficientNetV2B0, True)
add_model_to_registry(EfficientNetV2B1, True)
add_model_to_registry(EfficientNetV2B2, True)
add_model_to_registry(EfficientNetV2B3, True)
add_model_to_registry(TinyNetA, True)
add_model_to_registry(TinyNetB, True)
add_model_to_registry(TinyNetC, True)
add_model_to_registry(TinyNetD, True)
add_model_to_registry(TinyNetE, True)
add_model_to_registry(EfficientNetB0, "imagenet")
add_model_to_registry(EfficientNetB1, "imagenet")
add_model_to_registry(EfficientNetB2, "imagenet")
add_model_to_registry(EfficientNetB3, "imagenet")
add_model_to_registry(EfficientNetB4, "imagenet")
add_model_to_registry(EfficientNetB5, "imagenet")
add_model_to_registry(EfficientNetB6, "imagenet")
add_model_to_registry(EfficientNetB7, "imagenet")
add_model_to_registry(EfficientNetLiteB0, "imagenet")
add_model_to_registry(EfficientNetLiteB1, "imagenet")
add_model_to_registry(EfficientNetLiteB2, "imagenet")
add_model_to_registry(EfficientNetLiteB3, "imagenet")
add_model_to_registry(EfficientNetLiteB4, "imagenet")
add_model_to_registry(EfficientNetV2S, "imagenet")
add_model_to_registry(EfficientNetV2M, "imagenet")
add_model_to_registry(EfficientNetV2L, "imagenet")
add_model_to_registry(EfficientNetV2XL, "imagenet")
add_model_to_registry(EfficientNetV2B0, "imagenet")
add_model_to_registry(EfficientNetV2B1, "imagenet")
add_model_to_registry(EfficientNetV2B2, "imagenet")
add_model_to_registry(EfficientNetV2B3, "imagenet")
add_model_to_registry(TinyNetA, "imagenet")
add_model_to_registry(TinyNetB, "imagenet")
add_model_to_registry(TinyNetC, "imagenet")
add_model_to_registry(TinyNetD, "imagenet")
add_model_to_registry(TinyNetE, "imagenet")
18 changes: 8 additions & 10 deletions kimm/models/efficientnet_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,14 @@ def test_efficentnet_feature_extractor(
self, model_class, width, fix_stem_channels
):
x = random.uniform([1, 224, 224, 3]) * 255.0
model = model_class(
input_shape=[224, 224, 3], as_feature_extractor=True
)
model = model_class(input_shape=[224, 224, 3], feature_extractor=True)

y = model(x, training=False)

self.assertIsInstance(y, dict)
self.assertAllEqual(
list(y.keys()), model_class.available_feature_keys()
self.assertContainsSubset(
model_class.available_feature_keys(),
list(y.keys()),
)
if fix_stem_channels:
self.assertEqual(list(y["STEM_S2"].shape), [1, 112, 112, 32])
Expand Down Expand Up @@ -86,15 +85,14 @@ def test_efficentnet_feature_extractor(
)
def test_efficentnet_v2_feature_extractor(self, model_class, width):
x = random.uniform([1, 224, 224, 3]) * 255.0
model = model_class(
input_shape=[224, 224, 3], as_feature_extractor=True
)
model = model_class(input_shape=[224, 224, 3], feature_extractor=True)

y = model(x, training=False)

self.assertIsInstance(y, dict)
self.assertAllEqual(
list(y.keys()), model_class.available_feature_keys()
self.assertContainsSubset(
model_class.available_feature_keys(),
list(y.keys()),
)
if "EfficientNetV2S" in model_class.__name__:
self.assertEqual(list(y["STEM_S2"].shape), [1, 112, 112, 24])
Expand Down
12 changes: 6 additions & 6 deletions kimm/models/ghostnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,9 +595,9 @@ def __init__(
)


add_model_to_registry(GhostNet050, False)
add_model_to_registry(GhostNet100, True)
add_model_to_registry(GhostNet130, True)
add_model_to_registry(GhostNet100V2, True)
add_model_to_registry(GhostNet130V2, True)
add_model_to_registry(GhostNet160V2, True)
add_model_to_registry(GhostNet050)
add_model_to_registry(GhostNet100, "imagenet")
add_model_to_registry(GhostNet130, "imagenet")
add_model_to_registry(GhostNet100V2, "imagenet")
add_model_to_registry(GhostNet130V2, "imagenet")
add_model_to_registry(GhostNet160V2, "imagenet")
14 changes: 8 additions & 6 deletions kimm/models/ghostnet_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,14 @@ def test_ghostnet_base(self, model_class):
@parameterized.named_parameters([(GhostNet100.__name__, GhostNet100)])
def test_ghostnet_feature_extractor(self, model_class):
x = random.uniform([1, 224, 224, 3]) * 255.0
model = model_class(as_feature_extractor=True)
model = model_class(feature_extractor=True)

y = model(x, training=False)

self.assertIsInstance(y, dict)
self.assertAllEqual(
list(y.keys()), model_class.available_feature_keys()
self.assertContainsSubset(
model_class.available_feature_keys(),
list(y.keys()),
)
self.assertEqual(list(y["STEM_S2"].shape), [1, 112, 112, 16])
self.assertEqual(list(y["BLOCK1_S4"].shape), [1, 56, 56, 24])
Expand All @@ -49,13 +50,14 @@ def test_ghostnetv2_base(self, model_class):
@parameterized.named_parameters([(GhostNet100V2.__name__, GhostNet100V2)])
def test_ghostnetv2_feature_extractor(self, model_class):
x = random.uniform([1, 224, 224, 3]) * 255.0
model = model_class(as_feature_extractor=True)
model = model_class(feature_extractor=True)

y = model(x, training=False)

self.assertIsInstance(y, dict)
self.assertAllEqual(
list(y.keys()), model_class.available_feature_keys()
self.assertContainsSubset(
model_class.available_feature_keys(),
list(y.keys()),
)
self.assertEqual(list(y["STEM_S2"].shape), [1, 112, 112, 16])
self.assertEqual(list(y["BLOCK1_S4"].shape), [1, 56, 56, 24])
Expand Down
2 changes: 1 addition & 1 deletion kimm/models/inception_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,4 +333,4 @@ def __init__(
)


add_model_to_registry(InceptionV3, True)
add_model_to_registry(InceptionV3, "imagenet")
7 changes: 4 additions & 3 deletions kimm/models/inception_v3_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,14 @@ def test_inception_v3_base(self, model_class):
@parameterized.named_parameters([(InceptionV3.__name__, InceptionV3)])
def test_inception_v3_feature_extractor(self, model_class):
x = random.uniform([1, 299, 299, 3]) * 255.0
model = model_class(as_feature_extractor=True)
model = model_class(feature_extractor=True)

y = model(x, training=False)

self.assertIsInstance(y, dict)
self.assertAllEqual(
list(y.keys()), model_class.available_feature_keys()
self.assertContainsSubset(
model_class.available_feature_keys(),
list(y.keys()),
)
self.assertEqual(list(y["STEM_S2"].shape), [1, 147, 147, 64])
self.assertEqual(list(y["BLOCK0_S4"].shape), [1, 71, 71, 192])
Expand Down
10 changes: 5 additions & 5 deletions kimm/models/mobilenet_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,8 +351,8 @@ def __init__(
)


add_model_to_registry(MobileNet050V2, True)
add_model_to_registry(MobileNet100V2, True)
add_model_to_registry(MobileNet110V2, True)
add_model_to_registry(MobileNet120V2, True)
add_model_to_registry(MobileNet140V2, True)
add_model_to_registry(MobileNet050V2, "imagenet")
add_model_to_registry(MobileNet100V2, "imagenet")
add_model_to_registry(MobileNet110V2, "imagenet")
add_model_to_registry(MobileNet120V2, "imagenet")
add_model_to_registry(MobileNet140V2, "imagenet")
7 changes: 4 additions & 3 deletions kimm/models/mobilenet_v2_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,14 @@ def test_mobilenet_v2_base(self, model_class):
)
def test_mobilenet_v2_feature_extractor(self, model_class, width):
x = random.uniform([1, 224, 224, 3]) * 255.0
model = model_class(as_feature_extractor=True)
model = model_class(feature_extractor=True)

y = model(x, training=False)

self.assertIsInstance(y, dict)
self.assertAllEqual(
list(y.keys()), model_class.available_feature_keys()
self.assertContainsSubset(
model_class.available_feature_keys(),
list(y.keys()),
)
self.assertEqual(
list(y["STEM_S2"].shape), [1, 112, 112, make_divisible(32 * width)]
Expand Down
22 changes: 11 additions & 11 deletions kimm/models/mobilenet_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -819,14 +819,14 @@ def available_feature_keys():
return feature_keys


add_model_to_registry(MobileNet050V3Small, True)
add_model_to_registry(MobileNet075V3Small, True)
add_model_to_registry(MobileNet100V3Small, True)
add_model_to_registry(MobileNet100V3SmallMinimal, True)
add_model_to_registry(MobileNet100V3Large, True)
add_model_to_registry(MobileNet100V3LargeMinimal, True)
add_model_to_registry(LCNet035, False)
add_model_to_registry(LCNet050, True)
add_model_to_registry(LCNet075, True)
add_model_to_registry(LCNet100, True)
add_model_to_registry(LCNet150, False)
add_model_to_registry(MobileNet050V3Small, "imagenet")
add_model_to_registry(MobileNet075V3Small, "imagenet")
add_model_to_registry(MobileNet100V3Small, "imagenet")
add_model_to_registry(MobileNet100V3SmallMinimal, "imagenet")
add_model_to_registry(MobileNet100V3Large, "imagenet")
add_model_to_registry(MobileNet100V3LargeMinimal, "imagenet")
add_model_to_registry(LCNet035)
add_model_to_registry(LCNet050, "imagenet")
add_model_to_registry(LCNet075, "imagenet")
add_model_to_registry(LCNet100, "imagenet")
add_model_to_registry(LCNet150)
Loading

0 comments on commit 6e37c90

Please sign in to comment.