From 3b973ca9d3ddafa20c2401245c551b5dac824bf9 Mon Sep 17 00:00:00 2001
From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com>
Date: Sun, 21 Jan 2024 15:35:43 +0800
Subject: [PATCH 1/4] Update `available_feature_keys` and `available_weights`

---
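Notes: each model now declares `available_feature_keys` and
`available_weights` as plain class attributes, and pretrained-weight URL
resolution moves into `BaseModel.get_weights_url()`. A minimal standalone
sketch of the lookup pattern this patch introduces (`DemoBase` and
`DemoModel` are hypothetical names for illustration only; the real classes
live in `kimm.models`):

    class DemoBase:
        default_origin = (
            "https://github.com/james77777778/kimm/releases/download/0.1.0/"
        )
        # Entries are (weights_name, origin, file_name) triples.
        available_weights = []

        def get_weights_url(self, weights):
            if weights is None:
                return None
            for _weights, _origin, _file_name in self.available_weights:
                if weights == _weights:
                    return f"{_origin}/{_file_name}"
            _names = [w for w, _, _ in self.available_weights]
            raise ValueError(
                f"Available weights are {_names}. Received weights={weights}"
            )

    class DemoModel(DemoBase):
        available_weights = [
            ("imagenet", DemoBase.default_origin, "demomodel.in1k.keras")
        ]

    print(DemoModel().get_weights_url("imagenet"))  # origin + file name
    # DemoModel().get_weights_url("unknown")  would raise ValueError
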
" + f"Received weights={weights}" + ) diff --git a/kimm/models/convmixer.py b/kimm/models/convmixer.py index fe2212b..ce93f38 100644 --- a/kimm/models/convmixer.py +++ b/kimm/models/convmixer.py @@ -52,6 +52,9 @@ def __init__( activation: str = "relu", **kwargs, ): + kwargs = self.fix_config(kwargs) + kwargs["weights_url"] = self.get_weights_url(kwargs["weights"]) + input_tensor = kwargs.pop("input_tensor", None) self.set_properties(kwargs) inputs = self.determine_input_tensor( @@ -100,10 +103,6 @@ def __init__( self.kernel_size = kernel_size self.activation = activation - @staticmethod - def available_feature_keys(): - raise NotImplementedError - def get_config(self): config = super().get_config() config.update( @@ -136,6 +135,15 @@ def fix_config(self, config): class ConvMixer736D32(ConvMixer): + available_feature_keys = ["STEM", *[f"BLOCK{i}" for i in range(32)]] + available_weights = [ + ( + "imagenet", + ConvMixer.default_origin, + "convmixer736d32_convmixer_768_32.in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -150,10 +158,6 @@ def __init__( name: str = "ConvMixer736D32", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "convmixer736d32_convmixer_768_32.in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( 32, 768, @@ -173,14 +177,17 @@ def __init__( **kwargs, ) - @staticmethod - def available_feature_keys(): - feature_keys = ["STEM"] - feature_keys.extend([f"BLOCK{i}" for i in range(32)]) - return feature_keys - class ConvMixer1024D20(ConvMixer): + available_feature_keys = ["STEM", *[f"BLOCK{i}" for i in range(20)]] + available_weights = [ + ( + "imagenet", + ConvMixer.default_origin, + "convmixer1024d20_convmixer_1024_20_ks9_p14.in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -195,10 +202,6 @@ def __init__( name: str = "ConvMixer1024D20", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "convmixer1024d20_convmixer_1024_20_ks9_p14.in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( 20, 1024, @@ -218,14 +221,17 @@ def __init__( **kwargs, ) - @staticmethod - def available_feature_keys(): - feature_keys = ["STEM"] - feature_keys.extend([f"BLOCK{i}" for i in range(20)]) - return feature_keys - class ConvMixer1536D20(ConvMixer): + available_feature_keys = ["STEM", *[f"BLOCK{i}" for i in range(20)]] + available_weights = [ + ( + "imagenet", + ConvMixer.default_origin, + "convmixer1536d20_convmixer_1536_20.in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -240,10 +246,6 @@ def __init__( name: str = "ConvMixer1536D20", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "convmixer1536d20_convmixer_1536_20.in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( 20, 1536, @@ -263,12 +265,6 @@ def __init__( **kwargs, ) - @staticmethod - def available_feature_keys(): - feature_keys = ["STEM"] - feature_keys.extend([f"BLOCK{i}" for i in range(20)]) - return feature_keys - add_model_to_registry(ConvMixer736D32, "imagenet") add_model_to_registry(ConvMixer1024D20, "imagenet") diff --git a/kimm/models/convnext.py b/kimm/models/convnext.py index 06fd880..c890e42 100644 --- a/kimm/models/convnext.py +++ b/kimm/models/convnext.py @@ -120,6 +120,11 @@ def apply_convnext_stage( @keras.saving.register_keras_serializable(package="kimm") class ConvNeXt(BaseModel): 
+ available_feature_keys = [ + "STEM_S4", + *[f"BLOCK{i}_S{2**(i+2)}" for i in range(4)], + ] + def __init__( self, depths: typing.Sequence[int] = [3, 3, 9, 3], @@ -130,6 +135,9 @@ def __init__( use_conv_mlp: bool = False, **kwargs, ): + kwargs = self.fix_config(kwargs) + kwargs["weights_url"] = self.get_weights_url(kwargs["weights"]) + input_tensor = kwargs.pop("input_tensor", None) self.set_properties(kwargs) inputs = self.determine_input_tensor( @@ -197,12 +205,6 @@ def build_top(self, inputs, classes, classifier_activation, dropout_rate): )(x) return x - @staticmethod - def available_feature_keys(): - feature_keys = ["STEM_S4"] - feature_keys.extend([f"BLOCK{i}_S{2**(i+2)}" for i in range(4)]) - return feature_keys - def get_config(self): config = super().get_config() config.update( @@ -237,6 +239,14 @@ def fix_config(self, config): class ConvNeXtAtto(ConvNeXt): + available_weights = [ + ( + "imagenet", + ConvNeXt.default_origin, + "convnextatto_convnext_atto.d2_in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -251,10 +261,6 @@ def __init__( name: str = "ConvNeXtAtto", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "convnextatto_convnext_atto.d2_in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( (2, 2, 6, 2), (40, 80, 160, 320), @@ -277,6 +283,14 @@ def __init__( class ConvNeXtFemto(ConvNeXt): + available_weights = [ + ( + "imagenet", + ConvNeXt.default_origin, + "convnextfemto_convnext_femto.d1_in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -291,10 +305,6 @@ def __init__( name: str = "ConvNeXtFemto", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "convnextfemto_convnext_femto.d1_in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( (2, 2, 6, 2), (48, 96, 192, 384), @@ -317,6 +327,14 @@ def __init__( class ConvNeXtPico(ConvNeXt): + available_weights = [ + ( + "imagenet", + ConvNeXt.default_origin, + "convnextpico_convnext_pico.d1_in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -331,10 +349,6 @@ def __init__( name: str = "ConvNeXtPico", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "convnextpico_convnext_pico.d1_in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( (2, 2, 6, 2), (64, 128, 256, 512), @@ -357,6 +371,14 @@ def __init__( class ConvNeXtNano(ConvNeXt): + available_weights = [ + ( + "imagenet", + ConvNeXt.default_origin, + "convnextnano_convnext_nano.in12k_ft_in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -371,10 +393,6 @@ def __init__( name: str = "ConvNeXtNano", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "convnextnano_convnext_nano.in12k_ft_in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( (2, 2, 8, 2), (80, 160, 320, 640), @@ -397,6 +415,14 @@ def __init__( class ConvNeXtTiny(ConvNeXt): + available_weights = [ + ( + "imagenet", + ConvNeXt.default_origin, + "convnexttiny_convnext_tiny.in12k_ft_in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -411,10 +437,6 @@ def __init__( name: str = "ConvNeXtTiny", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "convnexttiny_convnext_tiny.in12k_ft_in1k.keras" - 
kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( (3, 3, 9, 3), (96, 192, 384, 768), @@ -437,6 +459,14 @@ def __init__( class ConvNeXtSmall(ConvNeXt): + available_weights = [ + ( + "imagenet", + ConvNeXt.default_origin, + "convnextsmall_convnext_small.in12k_ft_in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -451,10 +481,6 @@ def __init__( name: str = "ConvNeXtSmall", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "convnextsmall_convnext_small.in12k_ft_in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( (3, 3, 27, 3), (96, 192, 384, 768), @@ -477,6 +503,14 @@ def __init__( class ConvNeXtBase(ConvNeXt): + available_weights = [ + ( + "imagenet", + ConvNeXt.default_origin, + "convnextbase_convnext_base.fb_in22k_ft_in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -491,10 +525,6 @@ def __init__( name: str = "ConvNeXtBase", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "convnextbase_convnext_base.fb_in22k_ft_in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( (3, 3, 27, 3), (128, 256, 512, 1024), @@ -517,6 +547,14 @@ def __init__( class ConvNeXtLarge(ConvNeXt): + available_weights = [ + ( + "imagenet", + ConvNeXt.default_origin, + "convnextlarge_convnext_large.fb_in22k_ft_in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -531,10 +569,6 @@ def __init__( name: str = "ConvNeXtLarge", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "convnextlarge_convnext_large.fb_in22k_ft_in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( (3, 3, 27, 3), (192, 384, 768, 1536), @@ -557,6 +591,8 @@ def __init__( class ConvNeXtXLarge(ConvNeXt): + available_weights = [] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -571,7 +607,6 @@ def __init__( name: str = "ConvNeXtXLarge", **kwargs, ): - kwargs = self.fix_config(kwargs) super().__init__( (3, 3, 27, 3), (256, 512, 1024, 2048), diff --git a/kimm/models/densenet.py b/kimm/models/densenet.py index 71d5e69..2162cd7 100644 --- a/kimm/models/densenet.py +++ b/kimm/models/densenet.py @@ -66,12 +66,20 @@ def apply_dense_transition_block( @keras.saving.register_keras_serializable(package="kimm") class DenseNet(BaseModel): + available_feature_keys = [ + "STEM_S4", + *[f"BLOCK{i}_S{j}" for i, j in zip(range(4), [8, 16, 32, 32])], + ] + def __init__( self, growth_rate: float = 32, num_blocks: typing.Sequence[int] = [6, 12, 24, 16], **kwargs, ): + kwargs = self.fix_config(kwargs) + kwargs["weights_url"] = self.get_weights_url(kwargs["weights"]) + input_tensor = kwargs.pop("input_tensor", None) self.set_properties(kwargs) inputs = self.determine_input_tensor( @@ -133,14 +141,6 @@ def __init__( self.growth_rate = growth_rate self.num_blocks = num_blocks - @staticmethod - def available_feature_keys(): - feature_keys = ["STEM_S4"] - feature_keys.extend( - [f"BLOCK{i}_S{j}" for i, j in zip(range(4), [8, 16, 32, 32])] - ) - return feature_keys - def get_config(self): config = super().get_config() config.update( @@ -161,6 +161,14 @@ def fix_config(self, config: typing.Dict): class DenseNet121(DenseNet): + available_weights = [ + ( + "imagenet", + DenseNet.default_origin, + "densenet121_densenet121.ra_in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ 
-175,10 +183,6 @@ def __init__( name: str = "DenseNet121", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "densenet121_densenet121.ra_in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( 32, [6, 12, 24, 16], @@ -198,6 +202,14 @@ def __init__( class DenseNet161(DenseNet): + available_weights = [ + ( + "imagenet", + DenseNet.default_origin, + "densenet161_densenet161.tv_in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -212,10 +224,6 @@ def __init__( name: str = "DenseNet161", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "densenet161_densenet161.tv_in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( 48, [6, 12, 36, 24], @@ -235,6 +243,14 @@ def __init__( class DenseNet169(DenseNet): + available_weights = [ + ( + "imagenet", + DenseNet.default_origin, + "densenet169_densenet169.tv_in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -249,10 +265,6 @@ def __init__( name: str = "DenseNet169", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "densenet169_densenet169.tv_in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( 32, [6, 12, 32, 32], @@ -272,6 +284,14 @@ def __init__( class DenseNet201(DenseNet): + available_weights = [ + ( + "imagenet", + DenseNet.default_origin, + "densenet201_densenet201.tv_in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -286,10 +306,6 @@ def __init__( name: str = "DenseNet201", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "densenet201_densenet201.tv_in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( 32, [6, 12, 48, 32], diff --git a/kimm/models/efficientnet.py b/kimm/models/efficientnet.py index d328c3f..31c0e52 100644 --- a/kimm/models/efficientnet.py +++ b/kimm/models/efficientnet.py @@ -125,6 +125,16 @@ def apply_edge_residual_block( @keras.saving.register_keras_serializable(package="kimm") class EfficientNet(BaseModel): + # for: v1, v1_lite, v2_m, v2_l, v2_xl, tinynet + # not for: v2_s, v2_base + available_feature_keys = [ + "STEM_S2", + *[ + f"BLOCK{i}_S{j}" + for i, j in zip(range(7), [2, 4, 8, 16, 16, 32, 32]) + ], + ] + def __init__( self, width: float = 1.0, @@ -137,6 +147,9 @@ def __init__( config: str = "v1", **kwargs, ): + kwargs = self.fix_config(kwargs) + kwargs["weights_url"] = self.get_weights_url(kwargs["weights"]) + _available_configs = [ "v1", "v1_lite", @@ -270,19 +283,6 @@ def __init__( self.activation = activation self.config = config - @staticmethod - def available_feature_keys(): - # for: v1, v1_lite, v2_m, v2_l, v2_xl, tinynet - # not for: v2_s, v2_base - feature_keys = ["STEM_S2"] - feature_keys.extend( - [ - f"BLOCK{i}_S{j}" - for i, j in zip(range(7), [2, 4, 8, 16, 16, 32, 32]) - ] - ) - return feature_keys - def get_config(self): config = super().get_config() config.update( @@ -320,6 +320,14 @@ def fix_config(self, config: typing.Dict): class EfficientNetB0(EfficientNet): + available_weights = [ + ( + "imagenet", + EfficientNet.default_origin, + "efficientnetb0_tf_efficientnet_b0.ns_jft_in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -335,10 +343,6 @@ def __init__( name: str = "EfficientNetB0", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == 
"imagenet": - file_name = "efficientnetb0_tf_efficientnet_b0.ns_jft_in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" # default to TF configuration (bn_epsilon=1e-3 and padding="same") super().__init__( 1.0, @@ -367,6 +371,14 @@ def __init__( class EfficientNetB1(EfficientNet): + available_weights = [ + ( + "imagenet", + EfficientNet.default_origin, + "efficientnetb1_tf_efficientnet_b1.ns_jft_in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -382,10 +394,6 @@ def __init__( name: str = "EfficientNetB1", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "efficientnetb1_tf_efficientnet_b1.ns_jft_in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" # default to TF configuration (bn_epsilon=1e-3 and padding="same") super().__init__( 1.0, @@ -414,6 +422,14 @@ def __init__( class EfficientNetB2(EfficientNet): + available_weights = [ + ( + "imagenet", + EfficientNet.default_origin, + "efficientnetb2_tf_efficientnet_b2.ns_jft_in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -429,10 +445,6 @@ def __init__( name: str = "EfficientNetB2", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "efficientnetb2_tf_efficientnet_b2.ns_jft_in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" # default to TF configuration (bn_epsilon=1e-3 and padding="same") super().__init__( 1.1, @@ -461,6 +473,14 @@ def __init__( class EfficientNetB3(EfficientNet): + available_weights = [ + ( + "imagenet", + EfficientNet.default_origin, + "efficientnetb3_tf_efficientnet_b3.ns_jft_in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -476,10 +496,6 @@ def __init__( name: str = "EfficientNetB3", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "efficientnetb3_tf_efficientnet_b3.ns_jft_in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" # default to TF configuration (bn_epsilon=1e-3 and padding="same") super().__init__( 1.2, @@ -508,6 +524,14 @@ def __init__( class EfficientNetB4(EfficientNet): + available_weights = [ + ( + "imagenet", + EfficientNet.default_origin, + "efficientnetb4_tf_efficientnet_b4.ns_jft_in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -523,10 +547,6 @@ def __init__( name: str = "EfficientNetB4", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "efficientnetb4_tf_efficientnet_b4.ns_jft_in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" # default to TF configuration (bn_epsilon=1e-3 and padding="same") super().__init__( 1.4, @@ -555,6 +575,14 @@ def __init__( class EfficientNetB5(EfficientNet): + available_weights = [ + ( + "imagenet", + EfficientNet.default_origin, + "efficientnetb5_tf_efficientnet_b5.ns_jft_in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -570,10 +598,6 @@ def __init__( name: str = "EfficientNetB5", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "efficientnetb5_tf_efficientnet_b5.ns_jft_in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" # default to TF configuration (bn_epsilon=1e-3 and padding="same") super().__init__( 1.6, @@ -602,6 +626,14 @@ def __init__( class EfficientNetB6(EfficientNet): + available_weights = [ + ( + "imagenet", + EfficientNet.default_origin, + 
"efficientnetb6_tf_efficientnet_b6.ns_jft_in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -617,10 +649,6 @@ def __init__( name: str = "EfficientNetB6", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "efficientnetb6_tf_efficientnet_b6.ns_jft_in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" # default to TF configuration (bn_epsilon=1e-3 and padding="same") super().__init__( 1.8, @@ -649,6 +677,14 @@ def __init__( class EfficientNetB7(EfficientNet): + available_weights = [ + ( + "imagenet", + EfficientNet.default_origin, + "efficientnetb7_tf_efficientnet_b7.ns_jft_in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -664,10 +700,6 @@ def __init__( name: str = "EfficientNetB7", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "efficientnetb7_tf_efficientnet_b7.ns_jft_in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" # default to TF configuration (bn_epsilon=1e-3 and padding="same") super().__init__( 2.0, @@ -696,6 +728,14 @@ def __init__( class EfficientNetLiteB0(EfficientNet): + available_weights = [ + ( + "imagenet", + EfficientNet.default_origin, + "efficientnetliteb0_tf_efficientnet_lite0.in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -711,10 +751,6 @@ def __init__( name: str = "EfficientNetLiteB0", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "efficientnetliteb0_tf_efficientnet_lite0.in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" # default to TF configuration (bn_epsilon=1e-3 and padding="same") super().__init__( 1.0, @@ -743,6 +779,14 @@ def __init__( class EfficientNetLiteB1(EfficientNet): + available_weights = [ + ( + "imagenet", + EfficientNet.default_origin, + "efficientnetliteb1_tf_efficientnet_lite1.in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -758,10 +802,6 @@ def __init__( name: str = "EfficientNetLiteB1", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "efficientnetliteb1_tf_efficientnet_lite1.in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" # default to TF configuration (bn_epsilon=1e-3 and padding="same") super().__init__( 1.0, @@ -790,6 +830,14 @@ def __init__( class EfficientNetLiteB2(EfficientNet): + available_weights = [ + ( + "imagenet", + EfficientNet.default_origin, + "efficientnetliteb2_tf_efficientnet_lite2.in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -805,10 +853,6 @@ def __init__( name: str = "EfficientNetLiteB2", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "efficientnetliteb2_tf_efficientnet_lite2.in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" # default to TF configuration (bn_epsilon=1e-3 and padding="same") super().__init__( 1.1, @@ -837,6 +881,14 @@ def __init__( class EfficientNetLiteB3(EfficientNet): + available_weights = [ + ( + "imagenet", + EfficientNet.default_origin, + "efficientnetliteb3_tf_efficientnet_lite3.in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -852,10 +904,6 @@ def __init__( name: str = "EfficientNetLiteB3", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "efficientnetliteb3_tf_efficientnet_lite3.in1k.keras" - kwargs["weights_url"] 
= f"{self.default_origin}/{file_name}" # default to TF configuration (bn_epsilon=1e-3 and padding="same") super().__init__( 1.2, @@ -884,6 +932,14 @@ def __init__( class EfficientNetLiteB4(EfficientNet): + available_weights = [ + ( + "imagenet", + EfficientNet.default_origin, + "efficientnetliteb4_tf_efficientnet_lite4.in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -899,10 +955,6 @@ def __init__( name: str = "EfficientNetLiteB4", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "efficientnetliteb4_tf_efficientnet_lite4.in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" # default to TF configuration (bn_epsilon=1e-3 and padding="same") super().__init__( 1.4, @@ -931,6 +983,18 @@ def __init__( class EfficientNetV2S(EfficientNet): + available_feature_keys = [ + "STEM_S2", + *[f"BLOCK{i}_S{j}" for i, j in zip(range(6), [2, 4, 8, 16, 16, 32])], + ] + available_weights = [ + ( + "imagenet", + EfficientNet.default_origin, + "efficientnetv2s_tf_efficientnetv2_s.in21k_ft_in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -946,12 +1010,6 @@ def __init__( name: str = "EfficientNetV2S", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = ( - "efficientnetv2s_tf_efficientnetv2_s.in21k_ft_in1k.keras" - ) - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" # default to TF configuration (bn_epsilon=1e-3 and padding="same") super().__init__( 1.0, @@ -978,16 +1036,16 @@ def __init__( **kwargs, ) - @staticmethod - def available_feature_keys(): - feature_keys = ["STEM_S2"] - feature_keys.extend( - [f"BLOCK{i}_S{j}" for i, j in zip(range(6), [2, 4, 8, 16, 16, 32])] - ) - return feature_keys - class EfficientNetV2M(EfficientNet): + available_weights = [ + ( + "imagenet", + EfficientNet.default_origin, + "efficientnetv2m_tf_efficientnetv2_m.in21k_ft_in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -1003,12 +1061,6 @@ def __init__( name: str = "EfficientNetV2M", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = ( - "efficientnetv2m_tf_efficientnetv2_m.in21k_ft_in1k.keras" - ) - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" # default to TF configuration (bn_epsilon=1e-3 and padding="same") super().__init__( 1.0, @@ -1037,6 +1089,14 @@ def __init__( class EfficientNetV2L(EfficientNet): + available_weights = [ + ( + "imagenet", + EfficientNet.default_origin, + "efficientnetv2l_tf_efficientnetv2_l.in21k_ft_in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -1052,12 +1112,6 @@ def __init__( name: str = "EfficientNetV2L", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = ( - "efficientnetv2l_tf_efficientnetv2_l.in21k_ft_in1k.keras" - ) - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" # default to TF configuration (bn_epsilon=1e-3 and padding="same") super().__init__( 1.0, @@ -1086,6 +1140,14 @@ def __init__( class EfficientNetV2XL(EfficientNet): + available_weights = [ + ( + "imagenet", + EfficientNet.default_origin, + "efficientnetv2xl_tf_efficientnetv2_xl.in21k_ft_in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -1101,12 +1163,6 @@ def __init__( name: str = "EfficientNetV2XL", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = ( - 
"efficientnetv2xl_tf_efficientnetv2_xl.in21k_ft_in1k.keras" - ) - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" # default to TF configuration (bn_epsilon=1e-3 and padding="same") super().__init__( 1.0, @@ -1135,6 +1191,18 @@ def __init__( class EfficientNetV2B0(EfficientNet): + available_feature_keys = [ + "STEM_S2", + *[f"BLOCK{i}_S{j}" for i, j in zip(range(6), [2, 4, 8, 16, 16, 32])], + ] + available_weights = [ + ( + "imagenet", + EfficientNet.default_origin, + "efficientnetv2b0_tf_efficientnetv2_b0.in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -1150,10 +1218,6 @@ def __init__( name: str = "EfficientNetV2B0", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "efficientnetv2b0_tf_efficientnetv2_b0.in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" # default to TF configuration (bn_epsilon=1e-3 and padding="same") super().__init__( 1.0, @@ -1180,16 +1244,20 @@ def __init__( **kwargs, ) - @staticmethod - def available_feature_keys(): - feature_keys = ["STEM_S2"] - feature_keys.extend( - [f"BLOCK{i}_S{j}" for i, j in zip(range(6), [2, 4, 8, 16, 16, 32])] - ) - return feature_keys - class EfficientNetV2B1(EfficientNet): + available_feature_keys = [ + "STEM_S2", + *[f"BLOCK{i}_S{j}" for i, j in zip(range(6), [2, 4, 8, 16, 16, 32])], + ] + available_weights = [ + ( + "imagenet", + EfficientNet.default_origin, + "efficientnetv2b1_tf_efficientnetv2_b1.in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -1205,10 +1273,6 @@ def __init__( name: str = "EfficientNetV2B1", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "efficientnetv2b1_tf_efficientnetv2_b1.in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" # default to TF configuration (bn_epsilon=1e-3 and padding="same") super().__init__( 1.0, @@ -1235,16 +1299,20 @@ def __init__( **kwargs, ) - @staticmethod - def available_feature_keys(): - feature_keys = ["STEM_S2"] - feature_keys.extend( - [f"BLOCK{i}_S{j}" for i, j in zip(range(6), [2, 4, 8, 16, 16, 32])] - ) - return feature_keys - class EfficientNetV2B2(EfficientNet): + available_feature_keys = [ + "STEM_S2", + *[f"BLOCK{i}_S{j}" for i, j in zip(range(6), [2, 4, 8, 16, 16, 32])], + ] + available_weights = [ + ( + "imagenet", + EfficientNet.default_origin, + "efficientnetv2b2_tf_efficientnetv2_b2.in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -1260,10 +1328,6 @@ def __init__( name: str = "EfficientNetV2B2", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "efficientnetv2b2_tf_efficientnetv2_b2.in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" # default to TF configuration (bn_epsilon=1e-3 and padding="same") super().__init__( 1.1, @@ -1291,16 +1355,20 @@ def __init__( **kwargs, ) - @staticmethod - def available_feature_keys(): - feature_keys = ["STEM_S2"] - feature_keys.extend( - [f"BLOCK{i}_S{j}" for i, j in zip(range(6), [2, 4, 8, 16, 16, 32])] - ) - return feature_keys - class EfficientNetV2B3(EfficientNet): + available_feature_keys = [ + "STEM_S2", + *[f"BLOCK{i}_S{j}" for i, j in zip(range(6), [2, 4, 8, 16, 16, 32])], + ] + available_weights = [ + ( + "imagenet", + EfficientNet.default_origin, + "efficientnetv2b3_tf_efficientnetv2_b3.in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -1316,10 +1384,6 @@ def __init__( name: str = 
"EfficientNetV2B3", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "efficientnetv2b3_tf_efficientnetv2_b3.in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" # default to TF configuration (bn_epsilon=1e-3 and padding="same") super().__init__( 1.2, @@ -1347,16 +1411,16 @@ def __init__( **kwargs, ) - @staticmethod - def available_feature_keys(): - feature_keys = ["STEM_S2"] - feature_keys.extend( - [f"BLOCK{i}_S{j}" for i, j in zip(range(6), [2, 4, 8, 16, 16, 32])] - ) - return feature_keys - class TinyNetA(EfficientNet): + available_weights = [ + ( + "imagenet", + EfficientNet.default_origin, + "tinyneta_tinynet_a.in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -1372,10 +1436,6 @@ def __init__( name: str = "TinyNetA", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "tinyneta_tinynet_a.in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( 1.0, 1.2, @@ -1402,6 +1462,14 @@ def __init__( class TinyNetB(EfficientNet): + available_weights = [ + ( + "imagenet", + EfficientNet.default_origin, + "tinynetb_tinynet_b.in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -1417,10 +1485,6 @@ def __init__( name: str = "TinyNetB", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "tinynetb_tinynet_b.in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( 0.75, 1.1, @@ -1447,6 +1511,14 @@ def __init__( class TinyNetC(EfficientNet): + available_weights = [ + ( + "imagenet", + EfficientNet.default_origin, + "tinynetc_tinynet_c.in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -1462,10 +1534,6 @@ def __init__( name: str = "TinyNetC", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "tinynetc_tinynet_c.in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( 0.54, 0.85, @@ -1492,6 +1560,14 @@ def __init__( class TinyNetD(EfficientNet): + available_weights = [ + ( + "imagenet", + EfficientNet.default_origin, + "tinynetd_tinynet_d.in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -1507,10 +1583,6 @@ def __init__( name: str = "TinyNetD", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "tinynetd_tinynet_d.in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( 0.54, 0.695, @@ -1537,6 +1609,14 @@ def __init__( class TinyNetE(EfficientNet): + available_weights = [ + ( + "imagenet", + EfficientNet.default_origin, + "tinynete_tinynet_e.in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -1552,10 +1632,6 @@ def __init__( name: str = "TinyNetE", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "tinynete_tinynet_e.in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( 0.51, 0.6, diff --git a/kimm/models/ghostnet.py b/kimm/models/ghostnet.py index 6dd923f..01cb8a8 100644 --- a/kimm/models/ghostnet.py +++ b/kimm/models/ghostnet.py @@ -230,6 +230,14 @@ def apply_ghost_bottleneck( @keras.saving.register_keras_serializable(package="kimm") class GhostNet(BaseModel): + available_feature_keys = [ + "STEM_S2", + *[ + f"BLOCK{i}_S{j}" + for i, j in zip(range(9), [2, 4, 4, 8, 8, 16, 16, 32, 32]) + ], + ] + 
def __init__( self, width: float = 1.0, @@ -237,6 +245,9 @@ def __init__( version: typing.Literal["v1", "v2"] = "v1", **kwargs, ): + kwargs = self.fix_config(kwargs) + kwargs["weights_url"] = self.get_weights_url(kwargs["weights"]) + _available_configs = ["default"] if config == "default": _config = DEFAULT_CONFIG @@ -332,17 +343,6 @@ def build_top(self, inputs, classes, classifier_activation, dropout_rate): )(x) return x - @staticmethod - def available_feature_keys(): - feature_keys = ["STEM_S2"] - feature_keys.extend( - [ - f"BLOCK{i}_S{j}" - for i, j in zip(range(9), [2, 4, 4, 8, 8, 16, 16, 32, 32]) - ] - ) - return feature_keys - def get_config(self): config = super().get_config() config.update( @@ -367,6 +367,8 @@ def fix_config(self, config): class GhostNet050(GhostNet): + available_weights = [] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -382,7 +384,6 @@ def __init__( name: str = "GhostNet050", **kwargs, ): - kwargs = self.fix_config(kwargs) super().__init__( 0.5, config, @@ -402,6 +403,14 @@ def __init__( class GhostNet100(GhostNet): + available_weights = [ + ( + "imagenet", + GhostNet.default_origin, + "ghostnet100_ghostnet_100.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -417,10 +426,6 @@ def __init__( name: str = "GhostNet100", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "ghostnet100_ghostnet_100.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( 1.0, config, @@ -440,6 +445,8 @@ def __init__( class GhostNet130(GhostNet): + available_weights = [] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -455,7 +462,6 @@ def __init__( name: str = "GhostNet130", **kwargs, ): - kwargs = self.fix_config(kwargs) super().__init__( 1.3, config, @@ -475,6 +481,14 @@ def __init__( class GhostNet100V2(GhostNet): + available_weights = [ + ( + "imagenet", + GhostNet.default_origin, + "ghostnet100v2_ghostnetv2_100.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -490,10 +504,6 @@ def __init__( name: str = "GhostNet100V2", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "ghostnet100v2_ghostnetv2_100.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( 1.0, config, @@ -513,6 +523,14 @@ def __init__( class GhostNet130V2(GhostNet): + available_weights = [ + ( + "imagenet", + GhostNet.default_origin, + "ghostnet130v2_ghostnetv2_130.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -528,10 +546,6 @@ def __init__( name: str = "GhostNet130V2", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "ghostnet130v2_ghostnetv2_130.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( 1.3, config, @@ -551,6 +565,14 @@ def __init__( class GhostNet160V2(GhostNet): + available_weights = [ + ( + "imagenet", + GhostNet.default_origin, + "ghostnet160v2_ghostnetv2_160.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -566,10 +588,6 @@ def __init__( name: str = "GhostNet160V2", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "ghostnet160v2_ghostnetv2_160.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( 1.6, config, diff --git a/kimm/models/inception_v3.py b/kimm/models/inception_v3.py index 50fee21..fddc500 100644 --- a/kimm/models/inception_v3.py +++ 
b/kimm/models/inception_v3.py @@ -203,7 +203,15 @@ def apply_inception_aux_block(inputs, classes, name="inception_aux_block"): @keras.saving.register_keras_serializable(package="kimm") class InceptionV3Base(BaseModel): + available_feature_keys = [ + "STEM_S2", + *[f"BLOCK{i}_S{j}" for i, j in zip(range(4), [4, 8, 16, 32])], + ] + def __init__(self, has_aux_logits: bool = False, **kwargs): + kwargs = self.fix_config(kwargs) + kwargs["weights_url"] = self.get_weights_url(kwargs["weights"]) + input_tensor = kwargs.pop("input_tensor", None) self.set_properties(kwargs, 299) inputs = self.determine_input_tensor( @@ -264,14 +272,6 @@ def __init__(self, has_aux_logits: bool = False, **kwargs): # All references to `self` below this line self.has_aux_logits = has_aux_logits - @staticmethod - def available_feature_keys(): - feature_keys = ["STEM_S2"] - feature_keys.extend( - [f"BLOCK{i}_S{j}" for i, j in zip(range(4), [4, 8, 16, 32])] - ) - return feature_keys - def get_config(self): config = super().get_config() config.update({"has_aux_logits": self.has_aux_logits}) @@ -282,6 +282,19 @@ def fix_config(self, config: typing.Dict): class InceptionV3(InceptionV3Base): + available_weights = [ + ( + "imagenet_aux_logits", + InceptionV3Base.default_origin, + "inceptionv3_inception_v3.gluon_in1k_aux_logits.keras", + ), + ( + "imagenet_no_aux_logits", + InceptionV3Base.default_origin, + "inceptionv3_inception_v3.gluon_in1k_no_aux_logits.keras", + ), + ] + def __init__( self, has_aux_logits: bool = False, @@ -297,17 +310,11 @@ def __init__( name: str = "InceptionV3", **kwargs, ): - kwargs = self.fix_config(kwargs) if weights == "imagenet": if has_aux_logits: - file_name = ( - "inceptionv3_inception_v3.gluon_in1k_aux_logits.keras" - ) + weights = f"{weights}_aux_logits" else: - file_name = ( - "inceptionv3_inception_v3.gluon_in1k_no_aux_logits.keras" - ) - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" + weights = f"{weights}_no_aux_logits" super().__init__( has_aux_logits, input_tensor=input_tensor, diff --git a/kimm/models/mobilenet_v2.py b/kimm/models/mobilenet_v2.py index 2af2426..2f33736 100644 --- a/kimm/models/mobilenet_v2.py +++ b/kimm/models/mobilenet_v2.py @@ -24,6 +24,14 @@ @keras.saving.register_keras_serializable(package="kimm") class MobileNetV2(BaseModel): + available_feature_keys = [ + "STEM_S2", + *[ + f"BLOCK{i}_S{j}" + for i, j in zip(range(7), [2, 4, 8, 16, 16, 32, 32]) + ], + ] + def __init__( self, width: float = 1.0, @@ -32,6 +40,9 @@ def __init__( config: typing.Literal["default"] = "default", **kwargs, ): + kwargs = self.fix_config(kwargs) + kwargs["weights_url"] = self.get_weights_url(kwargs["weights"]) + _available_configs = ["default"] if config == "default": _config = DEFAULT_CONFIG @@ -111,17 +122,6 @@ def __init__( self.fix_stem_and_head_channels = fix_stem_and_head_channels self.config = config - @staticmethod - def available_feature_keys(): - feature_keys = ["STEM_S2"] - feature_keys.extend( - [ - f"BLOCK{i}_S{j}" - for i, j in zip(range(7), [2, 4, 8, 16, 16, 32, 32]) - ] - ) - return feature_keys - def get_config(self): config = super().get_config() config.update( @@ -152,6 +152,14 @@ def fix_config(self, config): class MobileNet050V2(MobileNetV2): + available_weights = [ + ( + "imagenet", + MobileNetV2.default_origin, + "mobilenet050v2_mobilenetv2_050.lamb_in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -167,10 +175,6 @@ def __init__( name: str = "MobileNet050V2", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights 
== "imagenet": - file_name = "mobilenet050v2_mobilenetv2_050.lamb_in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( 0.5, 1.0, @@ -191,6 +195,14 @@ def __init__( class MobileNet100V2(MobileNetV2): + available_weights = [ + ( + "imagenet", + MobileNetV2.default_origin, + "mobilenet100v2_mobilenetv2_100.ra_in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -206,10 +218,6 @@ def __init__( name: str = "MobileNet100V2", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "mobilenet100v2_mobilenetv2_100.ra_in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( 1.0, 1.0, @@ -230,6 +238,14 @@ def __init__( class MobileNet110V2(MobileNetV2): + available_weights = [ + ( + "imagenet", + MobileNetV2.default_origin, + "mobilenet110v2_mobilenetv2_110d.ra_in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -245,10 +261,6 @@ def __init__( name: str = "MobileNet110V2", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "mobilenet110v2_mobilenetv2_110d.ra_in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( 1.1, 1.2, @@ -269,6 +281,14 @@ def __init__( class MobileNet120V2(MobileNetV2): + available_weights = [ + ( + "imagenet", + MobileNetV2.default_origin, + "mobilenet120v2_mobilenetv2_120d.ra_in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -284,10 +304,6 @@ def __init__( name: str = "MobileNet120V2", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "mobilenet120v2_mobilenetv2_120d.ra_in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( 1.2, 1.4, @@ -308,6 +324,14 @@ def __init__( class MobileNet140V2(MobileNetV2): + available_weights = [ + ( + "imagenet", + MobileNetV2.default_origin, + "mobilenet140v2_mobilenetv2_140.ra_in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -323,10 +347,6 @@ def __init__( name: str = "MobileNet140V2", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "mobilenet140v2_mobilenetv2_140.ra_in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( 1.4, 1.0, diff --git a/kimm/models/mobilenet_v3.py b/kimm/models/mobilenet_v3.py index 92b812d..af1daea 100644 --- a/kimm/models/mobilenet_v3.py +++ b/kimm/models/mobilenet_v3.py @@ -90,6 +90,9 @@ def __init__( minimal: bool = False, **kwargs, ): + kwargs = self.fix_config(kwargs) + kwargs["weights_url"] = self.get_weights_url(kwargs["weights"]) + _available_configs = ["small", "large", "lcnet"] if config == "small": _config = DEFAULT_SMALL_CONFIG @@ -270,10 +273,6 @@ def build_top( )(x) return x - @staticmethod - def available_feature_keys(): - raise NotImplementedError() - def get_config(self): config = super().get_config() config.update( @@ -306,6 +305,18 @@ def fix_config(self, config): class MobileNet050V3Small(MobileNetV3): + available_feature_keys = [ + "STEM_S2", + *[f"BLOCK{i}_S{j}" for i, j in zip(range(6), [4, 8, 16, 16, 32, 32])], + ] + available_weights = [ + ( + "imagenet", + MobileNetV3.default_origin, + "mobilenet050v3small_mobilenetv3_small_050.lamb_in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -321,12 +332,6 @@ def __init__( name: str = "MobileNet050V3Small", **kwargs, ): - kwargs = 
self.fix_config(kwargs) - if weights == "imagenet": - file_name = ( - "mobilenet050v3small_mobilenetv3_small_050.lamb_in1k.keras" - ) - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( 0.5, 1.0, @@ -345,16 +350,20 @@ def __init__( **kwargs, ) - @staticmethod - def available_feature_keys(): - feature_keys = ["STEM_S2"] - feature_keys.extend( - [f"BLOCK{i}_S{j}" for i, j in zip(range(6), [4, 8, 16, 16, 32, 32])] - ) - return feature_keys - class MobileNet075V3Small(MobileNetV3): + available_feature_keys = [ + "STEM_S2", + *[f"BLOCK{i}_S{j}" for i, j in zip(range(6), [4, 8, 16, 16, 32, 32])], + ] + available_weights = [ + ( + "imagenet", + MobileNetV3.default_origin, + "mobilenet075v3small_mobilenetv3_small_075.lamb_in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -370,12 +379,6 @@ def __init__( name: str = "MobileNet075V3Small", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = ( - "mobilenet075v3small_mobilenetv3_small_075.lamb_in1k.keras" - ) - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( 0.75, 1.0, @@ -394,16 +397,20 @@ def __init__( **kwargs, ) - @staticmethod - def available_feature_keys(): - feature_keys = ["STEM_S2"] - feature_keys.extend( - [f"BLOCK{i}_S{j}" for i, j in zip(range(6), [4, 8, 16, 16, 32, 32])] - ) - return feature_keys - class MobileNet100V3Small(MobileNetV3): + available_feature_keys = [ + "STEM_S2", + *[f"BLOCK{i}_S{j}" for i, j in zip(range(6), [4, 8, 16, 16, 32, 32])], + ] + available_weights = [ + ( + "imagenet", + MobileNetV3.default_origin, + "mobilenet100v3small_mobilenetv3_small_100.lamb_in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -419,12 +426,6 @@ def __init__( name: str = "MobileNet100V3Small", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = ( - "mobilenet100v3small_mobilenetv3_small_100.lamb_in1k.keras" - ) - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( 1.0, 1.0, @@ -443,16 +444,23 @@ def __init__( **kwargs, ) - @staticmethod - def available_feature_keys(): - feature_keys = ["STEM_S2"] - feature_keys.extend( - [f"BLOCK{i}_S{j}" for i, j in zip(range(6), [4, 8, 16, 16, 32, 32])] - ) - return feature_keys - class MobileNet100V3SmallMinimal(MobileNetV3): + available_feature_keys = [ + "STEM_S2", + *[f"BLOCK{i}_S{j}" for i, j in zip(range(6), [4, 8, 16, 16, 32, 32])], + ] + available_weights = [ + ( + "imagenet", + MobileNetV3.default_origin, + ( + "mobilenet100v3smallminimal_" + "tf_mobilenetv3_small_minimal_100.in1k.keras" + ), + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -468,13 +476,6 @@ def __init__( name: str = "MobileNet100V3SmallMinimal", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = ( - "mobilenet100v3smallminimal_" - "tf_mobilenetv3_small_minimal_100.in1k.keras" - ) - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" # default to TF configuration (bn_epsilon=1e-3 and padding="same") super().__init__( 1.0, @@ -497,16 +498,26 @@ def __init__( **kwargs, ) - @staticmethod - def available_feature_keys(): - feature_keys = ["STEM_S2"] - feature_keys.extend( - [f"BLOCK{i}_S{j}" for i, j in zip(range(6), [4, 8, 16, 16, 32, 32])] - ) - return feature_keys - class MobileNet100V3Large(MobileNetV3): + available_feature_keys = [ + "STEM_S2", + *[ + f"BLOCK{i}_S{j}" + for i, j in zip(range(7), [2, 4, 8, 16, 16, 32, 32]) + ], + ] + 
available_weights = [ + ( + "imagenet", + MobileNetV3.default_origin, + ( + "mobilenet100v3large_" + "mobilenetv3_large_100.miil_in21k_ft_in1k.keras" + ), + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -522,13 +533,6 @@ def __init__( name: str = "MobileNet100V3Large", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = ( - "mobilenet100v3large_" - "mobilenetv3_large_100.miil_in21k_ft_in1k.keras" - ) - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( 1.0, 1.0, @@ -557,19 +561,26 @@ def build_preprocessing(self, inputs, mode="imagenet"): else: return super().build_preprocessing(inputs, mode) - @staticmethod - def available_feature_keys(): - feature_keys = ["STEM_S2"] - feature_keys.extend( - [ - f"BLOCK{i}_S{j}" - for i, j in zip(range(7), [2, 4, 8, 16, 16, 32, 32]) - ] - ) - return feature_keys - class MobileNet100V3LargeMinimal(MobileNetV3): + available_feature_keys = [ + "STEM_S2", + *[ + f"BLOCK{i}_S{j}" + for i, j in zip(range(7), [2, 4, 8, 16, 16, 32, 32]) + ], + ] + available_weights = [ + ( + "imagenet", + MobileNetV3.default_origin, + ( + "mobilenet100v3largeminimal_" + "tf_mobilenetv3_large_minimal_100.in1k.keras" + ), + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -586,12 +597,6 @@ def __init__( **kwargs, ): kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = ( - "mobilenet100v3largeminimal_" - "tf_mobilenetv3_large_minimal_100.in1k.keras" - ) - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" # default to TF configuration (bn_epsilon=1e-3 and padding="same") super().__init__( 1.0, @@ -614,19 +619,14 @@ def __init__( **kwargs, ) - @staticmethod - def available_feature_keys(): - feature_keys = ["STEM_S2"] - feature_keys.extend( - [ - f"BLOCK{i}_S{j}" - for i, j in zip(range(7), [2, 4, 8, 16, 16, 32, 32]) - ] - ) - return feature_keys - class LCNet035(MobileNetV3): + available_feature_keys = [ + "STEM_S2", + *[f"BLOCK{i}_S{j}" for i, j in zip(range(6), [2, 4, 8, 16, 16, 32])], + ] + available_weights = [] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -662,16 +662,20 @@ def __init__( **kwargs, ) - @staticmethod - def available_feature_keys(): - feature_keys = ["STEM_S2"] - feature_keys.extend( - [f"BLOCK{i}_S{j}" for i, j in zip(range(6), [2, 4, 8, 16, 16, 32])] - ) - return feature_keys - class LCNet050(MobileNetV3): + available_feature_keys = [ + "STEM_S2", + *[f"BLOCK{i}_S{j}" for i, j in zip(range(6), [2, 4, 8, 16, 16, 32])], + ] + available_weights = [ + ( + "imagenet", + MobileNetV3.default_origin, + "lcnet050_lcnet_050.ra2_in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -687,10 +691,6 @@ def __init__( name: str = "LCNet050", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "lcnet050_lcnet_050.ra2_in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" # default to TF configuration (bn_epsilon=1e-3 and padding="same") super().__init__( 0.5, @@ -710,16 +710,20 @@ def __init__( **kwargs, ) - @staticmethod - def available_feature_keys(): - feature_keys = ["STEM_S2"] - feature_keys.extend( - [f"BLOCK{i}_S{j}" for i, j in zip(range(6), [2, 4, 8, 16, 16, 32])] - ) - return feature_keys - class LCNet075(MobileNetV3): + available_feature_keys = [ + "STEM_S2", + *[f"BLOCK{i}_S{j}" for i, j in zip(range(6), [2, 4, 8, 16, 16, 32])], + ] + available_weights = [ + ( + "imagenet", + MobileNetV3.default_origin, + 
"lcnet075_lcnet_075.ra2_in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -735,10 +739,6 @@ def __init__( name: str = "LCNet075", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "lcnet075_lcnet_075.ra2_in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" # default to TF configuration (bn_epsilon=1e-3 and padding="same") super().__init__( 0.75, @@ -758,16 +758,20 @@ def __init__( **kwargs, ) - @staticmethod - def available_feature_keys(): - feature_keys = ["STEM_S2"] - feature_keys.extend( - [f"BLOCK{i}_S{j}" for i, j in zip(range(6), [2, 4, 8, 16, 16, 32])] - ) - return feature_keys - class LCNet100(MobileNetV3): + available_feature_keys = [ + "STEM_S2", + *[f"BLOCK{i}_S{j}" for i, j in zip(range(6), [2, 4, 8, 16, 16, 32])], + ] + available_weights = [ + ( + "imagenet", + MobileNetV3.default_origin, + "lcnet100_lcnet_100.ra2_in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -783,10 +787,6 @@ def __init__( name: str = "LCNet100", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "lcnet100_lcnet_100.ra2_in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" # default to TF configuration (bn_epsilon=1e-3 and padding="same") super().__init__( 1.0, @@ -806,16 +806,17 @@ def __init__( **kwargs, ) - @staticmethod - def available_feature_keys(): - feature_keys = ["STEM_S2"] - feature_keys.extend( - [f"BLOCK{i}_S{j}" for i, j in zip(range(6), [2, 4, 8, 16, 16, 32])] - ) - return feature_keys - class LCNet150(MobileNetV3): + available_feature_keys = [ + "STEM_S2", + *[ + f"BLOCK{i}_S{j}" + for i, j in zip(range(7), [2, 4, 8, 16, 16, 32, 32]) + ], + ] + available_weights = [] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -851,17 +852,6 @@ def __init__( **kwargs, ) - @staticmethod - def available_feature_keys(): - feature_keys = ["STEM_S2"] - feature_keys.extend( - [ - f"BLOCK{i}_S{j}" - for i, j in zip(range(7), [2, 4, 8, 16, 16, 32, 32]) - ] - ) - return feature_keys - add_model_to_registry(MobileNet050V3Small, "imagenet") add_model_to_registry(MobileNet075V3Small, "imagenet") diff --git a/kimm/models/mobilevit.py b/kimm/models/mobilevit.py index 882e272..31342fb 100644 --- a/kimm/models/mobilevit.py +++ b/kimm/models/mobilevit.py @@ -162,6 +162,11 @@ def apply_mobilevit_block( @keras.saving.register_keras_serializable(package="kimm") class MobileViT(BaseModel): + available_feature_keys = [ + "STEM_S2", + *[f"BLOCK{i}_S{j}" for i, j in zip(range(5), [2, 4, 8, 16, 32])], + ] + def __init__( self, stem_channels: int = 16, @@ -170,6 +175,9 @@ def __init__( config: str = "v1_s", **kwargs, ): + kwargs = self.fix_config(kwargs) + kwargs["weights_url"] = self.get_weights_url(kwargs["weights"]) + _available_configs = ["v1_s", "v1_xs", "v1_xss"] if config == "v1_s": _config = DEFAULT_V1_S_CONFIG @@ -258,14 +266,6 @@ def __init__( self.activation = activation self.config = config - @staticmethod - def available_feature_keys(): - feature_keys = ["STEM_S2"] - feature_keys.extend( - [f"BLOCK{i}_S{j}" for i, j in zip(range(5), [2, 4, 8, 16, 32])] - ) - return feature_keys - def get_config(self): config = super().get_config() config.update( @@ -291,6 +291,14 @@ def fix_config(self, config): class MobileViTS(MobileViT): + available_weights = [ + ( + "imagenet", + MobileViT.default_origin, + "mobilevits_mobilevit_s.cvnets_in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, 
@@ -306,10 +314,6 @@ def __init__( name="MobileViTS", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "mobilevits_mobilevit_s.cvnets_in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( 16, 640, @@ -330,6 +334,14 @@ def __init__( class MobileViTXS(MobileViT): + available_weights = [ + ( + "imagenet", + MobileViT.default_origin, + "mobilevitxs_mobilevit_xs.cvnets_in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -345,10 +357,6 @@ def __init__( name="MobileViTXS", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "mobilevitxs_mobilevit_xs.cvnets_in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( 16, 384, @@ -369,6 +377,14 @@ def __init__( class MobileViTXXS(MobileViT): + available_weights = [ + ( + "imagenet", + MobileViT.default_origin, + "mobilevitxxs_mobilevit_xxs.cvnets_in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -384,10 +400,6 @@ def __init__( name="MobileViTXXS", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "mobilevitxxs_mobilevit_xxs.cvnets_in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( 16, 320, diff --git a/kimm/models/models_test.py b/kimm/models/models_test.py index 2e3bc41..c5bfddf 100644 --- a/kimm/models/models_test.py +++ b/kimm/models/models_test.py @@ -379,7 +379,7 @@ def test_model_feature_extractor( self.assertIsInstance(y, dict) self.assertContainsSubset( - model_class.available_feature_keys(), list(y.keys()) + model_class.available_feature_keys, list(y.keys()) ) for feature_info in features: name, shape = feature_info diff --git a/kimm/models/regnet.py b/kimm/models/regnet.py index a67d23e..c0263a8 100644 --- a/kimm/models/regnet.py +++ b/kimm/models/regnet.py @@ -143,6 +143,11 @@ def apply_bottleneck_block( @keras.saving.register_keras_serializable(package="kimm") class RegNet(BaseModel): + available_feature_keys = [ + "STEM_S2", + *[f"BLOCK{i}_S{j}" for i, j in zip(range(4), [4, 8, 16, 32])], + ] + def __init__( self, w0: int = 80, @@ -153,6 +158,9 @@ def __init__( se_ratio: float = 0.0, **kwargs, ): + kwargs = self.fix_config(kwargs) + kwargs["weights_url"] = self.get_weights_url(kwargs["weights"]) + per_stage_config = _generate_regnet(w0, wa, wm, group_size, depth) input_tensor = kwargs.pop("input_tensor", None) @@ -202,14 +210,6 @@ def __init__( self.depth = depth self.se_ratio = se_ratio - @staticmethod - def available_feature_keys(): - feature_keys = ["STEM_S2"] - feature_keys.extend( - [f"BLOCK{i}_S{j}" for i, j in zip(range(4), [4, 8, 16, 32])] - ) - return feature_keys - def get_config(self): config = super().get_config() config.update( @@ -237,6 +237,14 @@ def fix_config(self, config): class RegNetX002(RegNet): + available_weights = [ + ( + "imagenet", + RegNet.default_origin, + "regnetx002_regnetx_002.pycls_in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -251,10 +259,6 @@ def __init__( name: str = "RegNetX002", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "regnetx002_regnetx_002.pycls_in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( 24, 36.44, @@ -276,6 +280,14 @@ def __init__( class RegNetY002(RegNet): + available_weights = [ + ( + "imagenet", + RegNet.default_origin, + 
"regnety002_regnety_002.pycls_in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -290,10 +302,6 @@ def __init__( name: str = "RegNetY002", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "regnety002_regnety_002.pycls_in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( 24, 36.44, @@ -316,6 +324,14 @@ def __init__( class RegNetX004(RegNet): + available_weights = [ + ( + "imagenet", + RegNet.default_origin, + "regnetx004_regnetx_004.pycls_in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -330,10 +346,6 @@ def __init__( name: str = "RegNetX004", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "regnetx004_regnetx_004.pycls_in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( 24, 24.48, @@ -355,6 +367,14 @@ def __init__( class RegNetY004(RegNet): + available_weights = [ + ( + "imagenet", + RegNet.default_origin, + "regnety004_regnety_004.tv2_in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -369,10 +389,6 @@ def __init__( name: str = "RegNetY004", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "regnety004_regnety_004.tv2_in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( 48, 27.89, @@ -395,6 +411,14 @@ def __init__( class RegNetX006(RegNet): + available_weights = [ + ( + "imagenet", + RegNet.default_origin, + "regnetx006_regnetx_006.pycls_in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -409,10 +433,6 @@ def __init__( name: str = "RegNetX006", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "regnetx006_regnetx_006.pycls_in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( 48, 36.97, @@ -434,6 +454,14 @@ def __init__( class RegNetY006(RegNet): + available_weights = [ + ( + "imagenet", + RegNet.default_origin, + "regnety006_regnety_006.pycls_in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -448,10 +476,6 @@ def __init__( name: str = "RegNetY006", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "regnety006_regnety_006.pycls_in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( 48, 32.54, @@ -474,6 +498,14 @@ def __init__( class RegNetX008(RegNet): + available_weights = [ + ( + "imagenet", + RegNet.default_origin, + "regnetx008_regnetx_008.tv2_in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -488,10 +520,6 @@ def __init__( name: str = "RegNetX008", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "regnetx008_regnetx_008.tv2_in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( 56, 35.73, @@ -513,6 +541,14 @@ def __init__( class RegNetY008(RegNet): + available_weights = [ + ( + "imagenet", + RegNet.default_origin, + "regnety008_regnety_008.pycls_in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -527,10 +563,6 @@ def __init__( name: str = "RegNetY008", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "regnety008_regnety_008.pycls_in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( 56, 
38.84, @@ -553,6 +585,14 @@ def __init__( class RegNetX016(RegNet): + available_weights = [ + ( + "imagenet", + RegNet.default_origin, + "regnetx016_regnetx_016.tv2_in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -567,10 +607,6 @@ def __init__( name: str = "RegNetX016", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "regnetx016_regnetx_016.tv2_in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( 80, 34.01, @@ -592,6 +628,14 @@ def __init__( class RegNetY016(RegNet): + available_weights = [ + ( + "imagenet", + RegNet.default_origin, + "regnety016_regnety_016.tv2_in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -606,10 +650,6 @@ def __init__( name: str = "RegNetY016", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "regnety016_regnety_016.tv2_in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( 48, 20.71, @@ -632,6 +672,14 @@ def __init__( class RegNetX032(RegNet): + available_weights = [ + ( + "imagenet", + RegNet.default_origin, + "regnetx032_regnetx_032.tv2_in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -646,10 +694,6 @@ def __init__( name: str = "RegNetX032", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "regnetx032_regnetx_032.tv2_in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( 88, 26.31, @@ -671,6 +715,14 @@ def __init__( class RegNetY032(RegNet): + available_weights = [ + ( + "imagenet", + RegNet.default_origin, + "regnety032_regnety_032.ra_in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -685,10 +737,6 @@ def __init__( name: str = "RegNetY032", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "regnety032_regnety_032.ra_in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( 80, 42.63, @@ -711,6 +759,14 @@ def __init__( class RegNetX040(RegNet): + available_weights = [ + ( + "imagenet", + RegNet.default_origin, + "regnetx040_regnetx_040.pycls_in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -725,10 +781,6 @@ def __init__( name: str = "RegNetX040", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "regnetx040_regnetx_040.pycls_in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( 96, 38.65, @@ -750,6 +802,14 @@ def __init__( class RegNetY040(RegNet): + available_weights = [ + ( + "imagenet", + RegNet.default_origin, + "regnety040_regnety_040.ra3_in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -764,10 +824,6 @@ def __init__( name: str = "RegNetY040", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "regnety040_regnety_040.ra3_in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( 96, 31.41, @@ -790,6 +846,14 @@ def __init__( class RegNetX064(RegNet): + available_weights = [ + ( + "imagenet", + RegNet.default_origin, + "regnetx064_regnetx_064.pycls_in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -804,10 +868,6 @@ def __init__( name: str = "RegNetX064", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = 
"regnetx064_regnetx_064.pycls_in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( 184, 60.83, @@ -829,6 +889,14 @@ def __init__( class RegNetY064(RegNet): + available_weights = [ + ( + "imagenet", + RegNet.default_origin, + "regnety064_regnety_064.ra3_in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -843,10 +911,6 @@ def __init__( name: str = "RegNetY064", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "regnety064_regnety_064.ra3_in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( 112, 33.22, @@ -869,6 +933,14 @@ def __init__( class RegNetX080(RegNet): + available_weights = [ + ( + "imagenet", + RegNet.default_origin, + "regnetx080_regnetx_080.tv2_in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -883,10 +955,6 @@ def __init__( name: str = "RegNetX080", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "regnetx080_regnetx_080.tv2_in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( 80, 49.56, @@ -908,6 +976,14 @@ def __init__( class RegNetY080(RegNet): + available_weights = [ + ( + "imagenet", + RegNet.default_origin, + "regnety080_regnety_080.ra3_in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -922,10 +998,6 @@ def __init__( name: str = "RegNetY080", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "regnety080_regnety_080.ra3_in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( 192, 76.82, @@ -948,6 +1020,14 @@ def __init__( class RegNetX120(RegNet): + available_weights = [ + ( + "imagenet", + RegNet.default_origin, + "regnetx120_regnetx_120.pycls_in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -962,10 +1042,6 @@ def __init__( name: str = "RegNetX120", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "regnetx120_regnetx_120.pycls_in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( 168, 73.36, @@ -987,6 +1063,14 @@ def __init__( class RegNetY120(RegNet): + available_weights = [ + ( + "imagenet", + RegNet.default_origin, + "regnety120_regnety_120.sw_in12k_ft_in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -1001,10 +1085,6 @@ def __init__( name: str = "RegNetY120", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "regnety120_regnety_120.sw_in12k_ft_in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( 168, 73.36, @@ -1027,6 +1107,14 @@ def __init__( class RegNetX160(RegNet): + available_weights = [ + ( + "imagenet", + RegNet.default_origin, + "regnetx160_regnetx_160.tv2_in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -1041,10 +1129,6 @@ def __init__( name: str = "RegNetX160", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "regnetx160_regnetx_160.tv2_in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( 216, 55.59, @@ -1066,6 +1150,14 @@ def __init__( class RegNetY160(RegNet): + available_weights = [ + ( + "imagenet", + RegNet.default_origin, + "regnety160_regnety_160.swag_ft_in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = 
None, @@ -1080,10 +1172,6 @@ def __init__( name: str = "RegNetY160", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "regnety160_regnety_160.swag_ft_in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( 200, 106.23, @@ -1106,6 +1194,14 @@ def __init__( class RegNetX320(RegNet): + available_weights = [ + ( + "imagenet", + RegNet.default_origin, + "regnetx320_regnetx_320.tv2_in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -1120,10 +1216,6 @@ def __init__( name: str = "RegNetX320", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "regnetx320_regnetx_320.tv2_in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( 320, 69.86, @@ -1145,6 +1237,14 @@ def __init__( class RegNetY320(RegNet): + available_weights = [ + ( + "imagenet", + RegNet.default_origin, + "regnety320_regnety_320.swag_ft_in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -1159,10 +1259,6 @@ def __init__( name: str = "RegNetY320", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "regnety320_regnety_320.swag_ft_in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( 232, 115.89, diff --git a/kimm/models/resnet.py b/kimm/models/resnet.py index 43fe944..90940b3 100644 --- a/kimm/models/resnet.py +++ b/kimm/models/resnet.py @@ -105,12 +105,20 @@ def apply_bottleneck_block( @keras.saving.register_keras_serializable(package="kimm") class ResNet(BaseModel): + available_feature_keys = [ + "STEM_S2", + *[f"BLOCK{i}_S{j}" for i, j in zip(range(4), [4, 8, 16, 32])], + ] + def __init__( self, block_fn: typing.Literal["basic", "bottleneck"], num_blocks: typing.Sequence[int], **kwargs, ): + kwargs = self.fix_config(kwargs) + kwargs["weights_url"] = self.get_weights_url(kwargs["weights"]) + if block_fn not in ("basic", "bottleneck"): raise ValueError( "`block_fn` must be one of ('basic', 'bottleneck'). 
" @@ -172,14 +180,6 @@ def __init__( self.block_fn = block_fn self.num_blocks = num_blocks - @staticmethod - def available_feature_keys(): - feature_keys = ["STEM_S2"] - feature_keys.extend( - [f"BLOCK{i}_S{j}" for i, j in zip(range(4), [4, 8, 16, 32])] - ) - return feature_keys - def get_config(self): config = super().get_config() config.update( @@ -200,6 +200,14 @@ def fix_config(self, config): class ResNet18(ResNet): + available_weights = [ + ( + "imagenet", + ResNet.default_origin, + "resnet18_resnet18.a1_in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -214,10 +222,6 @@ def __init__( name: str = "ResNet18", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "resnet18_resnet18.a1_in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( "basic", [2, 2, 2, 2], @@ -236,6 +240,14 @@ def __init__( class ResNet34(ResNet): + available_weights = [ + ( + "imagenet", + ResNet.default_origin, + "resnet34_resnet34.a1_in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -250,10 +262,6 @@ def __init__( name: str = "ResNet34", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "resnet34_resnet34.a1_in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( "basic", [3, 4, 6, 3], @@ -272,6 +280,14 @@ def __init__( class ResNet50(ResNet): + available_weights = [ + ( + "imagenet", + ResNet.default_origin, + "resnet50_resnet50.a1_in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -286,10 +302,6 @@ def __init__( name: str = "ResNet50", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "resnet50_resnet50.a1_in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( "bottleneck", [3, 4, 6, 3], @@ -308,6 +320,14 @@ def __init__( class ResNet101(ResNet): + available_weights = [ + ( + "imagenet", + ResNet.default_origin, + "resnet101_resnet101.a1_in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -322,10 +342,6 @@ def __init__( name: str = "ResNet101", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "resnet101_resnet101.a1_in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( "bottleneck", [3, 4, 23, 3], @@ -344,6 +360,14 @@ def __init__( class ResNet152(ResNet): + available_weights = [ + ( + "imagenet", + ResNet.default_origin, + "resnet152_resnet152.a1_in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -358,10 +382,6 @@ def __init__( name: str = "ResNet152", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "resnet152_resnet152.a1_in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( "bottleneck", [3, 8, 36, 3], diff --git a/kimm/models/vgg.py b/kimm/models/vgg.py index 77775d5..5ec17ea 100644 --- a/kimm/models/vgg.py +++ b/kimm/models/vgg.py @@ -108,7 +108,14 @@ def apply_conv_mlp_layer( @keras.saving.register_keras_serializable(package="kimm") class VGG(BaseModel): + available_feature_keys = [ + *[f"BLOCK{i}_S{j}" for i, j in zip(range(6), [1, 2, 4, 8, 16, 32])], + ] + def __init__(self, config: typing.Union[str, typing.List], **kwargs): + kwargs = self.fix_config(kwargs) + kwargs["weights_url"] = self.get_weights_url(kwargs["weights"]) + _available_configs = 
["vgg11", "vgg13", "vgg16", "vgg19"] if config == "vgg11": _config = DEFAULT_VGG11_CONFIG @@ -178,12 +185,6 @@ def __init__(self, config: typing.Union[str, typing.List], **kwargs): # All references to `self` below this line self.config = config - @staticmethod - def available_feature_keys(): - return [ - f"BLOCK{i}_S{j}" for i, j in zip(range(6), [1, 2, 4, 8, 16, 32]) - ] - def get_config(self): config = super().get_config() config.update({"config": self.config}) @@ -202,6 +203,14 @@ def fix_config(self, config: typing.Dict): class VGG11(VGG): + available_weights = [ + ( + "imagenet", + VGG.default_origin, + "vgg11_vgg11_bn.tv_in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -216,10 +225,6 @@ def __init__( name: str = "VGG11", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "vgg11_vgg11_bn.tv_in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( "vgg11", input_tensor=input_tensor, @@ -237,6 +242,14 @@ def __init__( class VGG13(VGG): + available_weights = [ + ( + "imagenet", + VGG.default_origin, + "vgg13_vgg13_bn.tv_in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -251,10 +264,6 @@ def __init__( name: str = "VGG13", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "vgg13_vgg13_bn.tv_in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( "vgg13", input_tensor=input_tensor, @@ -272,6 +281,14 @@ def __init__( class VGG16(VGG): + available_weights = [ + ( + "imagenet", + VGG.default_origin, + "vgg16_vgg16_bn.tv_in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -286,10 +303,6 @@ def __init__( name: str = "VGG16", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "vgg16_vgg16_bn.tv_in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( "vgg16", input_tensor=input_tensor, @@ -307,6 +320,14 @@ def __init__( class VGG19(VGG): + available_weights = [ + ( + "imagenet", + VGG.default_origin, + "vgg19_vgg19_bn.tv_in1k.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -321,10 +342,6 @@ def __init__( name: str = "VGG19", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "vgg19_vgg19_bn.tv_in1k.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( "vgg19", input_tensor=input_tensor, diff --git a/kimm/models/vision_transformer.py b/kimm/models/vision_transformer.py index 5da247d..2fdf81d 100644 --- a/kimm/models/vision_transformer.py +++ b/kimm/models/vision_transformer.py @@ -23,6 +23,9 @@ def __init__( pos_dropout_rate: float = 0.0, **kwargs, ): + kwargs = self.fix_config(kwargs) + kwargs["weights_url"] = self.get_weights_url(kwargs["weights"]) + input_tensor = kwargs.pop("input_tensor", None) self.set_properties(kwargs, 384) if self._pooling is not None: @@ -100,10 +103,6 @@ def build_top(self, inputs, classes, classifier_activation, dropout_rate): )(x) return x - @staticmethod - def available_feature_keys(): - raise NotImplementedError() - def get_config(self): config = super().get_config() config.update( @@ -142,6 +141,18 @@ def fix_config(self, config): class VisionTransformerTiny16(VisionTransformer): + available_feature_keys = [ + "EMBEDDING", + *[f"BLOCK{i}" for i in range(12)], + ] + available_weights = [ + ( + "imagenet", + 
VisionTransformer.default_origin, + "visiontransformertiny16_vit_tiny_patch16_384.keras", + ) + ] + def __init__( self, mlp_ratio: float = 4.0, @@ -160,10 +171,6 @@ def __init__( name: str = "VisionTransformerTiny16", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "visiontransformertiny16_vit_tiny_patch16_384.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( 16, 192, @@ -186,14 +193,14 @@ def __init__( **kwargs, ) - @staticmethod - def available_feature_keys(): - feature_keys = ["EMBEDDING"] - feature_keys.extend([f"BLOCK{i}" for i in range(12)]) - return feature_keys - class VisionTransformerTiny32(VisionTransformer): + available_feature_keys = [ + "EMBEDDING", + *[f"BLOCK{i}" for i in range(12)], + ] + available_weights = [] + def __init__( self, mlp_ratio: float = 4.0, @@ -235,14 +242,20 @@ def __init__( **kwargs, ) - @staticmethod - def available_feature_keys(): - feature_keys = ["EMBEDDING"] - feature_keys.extend([f"BLOCK{i}" for i in range(12)]) - return feature_keys - class VisionTransformerSmall16(VisionTransformer): + available_feature_keys = [ + "EMBEDDING", + *[f"BLOCK{i}" for i in range(12)], + ] + available_weights = [ + ( + "imagenet", + VisionTransformer.default_origin, + "visiontransformersmall16_vit_small_patch16_384.keras", + ) + ] + def __init__( self, mlp_ratio: float = 4.0, @@ -261,10 +274,6 @@ def __init__( name: str = "VisionTransformerSmall16", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "visiontransformersmall16_vit_small_patch16_384.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( 16, 384, @@ -287,14 +296,20 @@ def __init__( **kwargs, ) - @staticmethod - def available_feature_keys(): - feature_keys = ["EMBEDDING"] - feature_keys.extend([f"BLOCK{i}" for i in range(12)]) - return feature_keys - class VisionTransformerSmall32(VisionTransformer): + available_feature_keys = [ + "EMBEDDING", + *[f"BLOCK{i}" for i in range(12)], + ] + available_weights = [ + ( + "imagenet", + VisionTransformer.default_origin, + "visiontransformersmall32_vit_small_patch32_384.keras", + ) + ] + def __init__( self, mlp_ratio: float = 4.0, @@ -313,10 +328,6 @@ def __init__( name: str = "VisionTransformerSmall32", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "visiontransformersmall32_vit_small_patch32_384.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( 32, 384, @@ -339,14 +350,20 @@ def __init__( **kwargs, ) - @staticmethod - def available_feature_keys(): - feature_keys = ["EMBEDDING"] - feature_keys.extend([f"BLOCK{i}" for i in range(12)]) - return feature_keys - class VisionTransformerBase16(VisionTransformer): + available_feature_keys = [ + "EMBEDDING", + *[f"BLOCK{i}" for i in range(12)], + ] + available_weights = [ + ( + "imagenet", + VisionTransformer.default_origin, + "visiontransformerbase16_vit_base_patch16_384.keras", + ) + ] + def __init__( self, mlp_ratio: float = 4.0, @@ -365,10 +382,6 @@ def __init__( name: str = "VisionTransformerBase16", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "visiontransformerbase16_vit_base_patch16_384.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( 16, 768, @@ -391,14 +404,20 @@ def __init__( **kwargs, ) - @staticmethod - def available_feature_keys(): - feature_keys = ["EMBEDDING"] - feature_keys.extend([f"BLOCK{i}" for i 
in range(12)]) - return feature_keys - class VisionTransformerBase32(VisionTransformer): + available_feature_keys = [ + "EMBEDDING", + *[f"BLOCK{i}" for i in range(12)], + ] + available_weights = [ + ( + "imagenet", + VisionTransformer.default_origin, + "visiontransformerbase32_vit_base_patch32_384.keras", + ) + ] + def __init__( self, mlp_ratio: float = 4.0, @@ -417,10 +436,6 @@ def __init__( name: str = "VisionTransformerBase32", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "visiontransformerbase32_vit_base_patch32_384.keras" - kwargs["weights_url"] = f"{self.default_origin}/{file_name}" super().__init__( 32, 768, @@ -443,14 +458,14 @@ def __init__( **kwargs, ) - @staticmethod - def available_feature_keys(): - feature_keys = ["EMBEDDING"] - feature_keys.extend([f"BLOCK{i}" for i in range(12)]) - return feature_keys - class VisionTransformerLarge16(VisionTransformer): + available_feature_keys = [ + "EMBEDDING", + *[f"BLOCK{i}" for i in range(24)], + ] + available_weights = [] + def __init__( self, mlp_ratio: float = 4.0, @@ -469,7 +484,6 @@ def __init__( name: str = "VisionTransformerLarge16", **kwargs, ): - kwargs = self.fix_config(kwargs) super().__init__( 16, 1024, @@ -492,14 +506,14 @@ def __init__( **kwargs, ) - @staticmethod - def available_feature_keys(): - feature_keys = ["EMBEDDING"] - feature_keys.extend([f"BLOCK{i}" for i in range(24)]) - return feature_keys - class VisionTransformerLarge32(VisionTransformer): + available_feature_keys = [ + "EMBEDDING", + *[f"BLOCK{i}" for i in range(24)], + ] + available_weights = [] + def __init__( self, mlp_ratio: float = 4.0, @@ -518,7 +532,6 @@ def __init__( name: str = "VisionTransformerLarge32", **kwargs, ): - kwargs = self.fix_config(kwargs) super().__init__( 32, 1024, @@ -541,12 +554,6 @@ def __init__( **kwargs, ) - @staticmethod - def available_feature_keys(): - feature_keys = ["EMBEDDING"] - feature_keys.extend([f"BLOCK{i}" for i in range(24)]) - return feature_keys - add_model_to_registry(VisionTransformerTiny16, "imagenet") add_model_to_registry(VisionTransformerTiny32) diff --git a/kimm/models/xception.py b/kimm/models/xception.py index 21fcaa6..90517f6 100644 --- a/kimm/models/xception.py +++ b/kimm/models/xception.py @@ -66,7 +66,15 @@ def apply_xception_block( @keras.saving.register_keras_serializable(package="kimm") class XceptionBase(BaseModel): + available_feature_keys = [ + "STEM_S2", + *[f"BLOCK{i}_S{j}" for i, j in zip(range(4), [4, 8, 16, 32])], + ] + def __init__(self, **kwargs): + kwargs = self.fix_config(kwargs) + kwargs["weights_url"] = self.get_weights_url(kwargs["weights"]) + input_tensor = kwargs.pop("input_tensor", None) self.set_properties(kwargs) inputs = self.determine_input_tensor( @@ -133,14 +141,6 @@ def __init__(self, **kwargs): # All references to `self` below this line - @staticmethod - def available_feature_keys(): - feature_keys = ["STEM_S2"] - feature_keys.extend( - [f"BLOCK{i}_S{j}" for i, j in zip(range(4), [4, 8, 16, 32])] - ) - return feature_keys - def get_config(self): return super().get_config() @@ -154,6 +154,14 @@ def fix_config(self, config: typing.Dict): class Xception(XceptionBase): + available_weights = [ + ( + "imagenet", + XceptionBase.default_origin, + "xception.keras", + ) + ] + def __init__( self, input_tensor: keras.KerasTensor = None, @@ -168,10 +176,6 @@ def __init__( name: str = "Xception", **kwargs, ): - kwargs = self.fix_config(kwargs) - if weights == "imagenet": - file_name = "xception.keras" - kwargs["weights_url"] = 
f"{self.default_origin}/{file_name}" super().__init__( input_tensor=input_tensor, input_shape=input_shape, diff --git a/kimm/utils/model_registry.py b/kimm/utils/model_registry.py index a80b93f..5f5fc40 100644 --- a/kimm/utils/model_registry.py +++ b/kimm/utils/model_registry.py @@ -48,7 +48,7 @@ def add_model_to_registry(model_cls, weights: typing.Optional[str] = None): feature_keys = [] if issubclass(model_cls, BaseModel): feature_extractor = True - feature_keys = model_cls.available_feature_keys() + feature_keys = model_cls.available_feature_keys for info in MODEL_REGISTRY: if info["name"] == model_cls.__name__: warnings.warn( diff --git a/kimm/utils/model_registry_test.py b/kimm/utils/model_registry_test.py index 7e3d592..93a1047 100644 --- a/kimm/utils/model_registry_test.py +++ b/kimm/utils/model_registry_test.py @@ -13,9 +13,7 @@ class DummyModel(models.Model): class DummyFeatureExtractor(BaseModel): - @staticmethod - def available_feature_keys(): - return ["A", "B", "C"] + available_feature_keys = ["A", "B", "C"] class ModelRegistryTest(testing.TestCase): From 42ad81298c943fd3e4dbdf18747aa1705cd14fb6 Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Sun, 21 Jan 2024 20:35:01 +0800 Subject: [PATCH 2/4] Fix serialization --- kimm/models/convmixer.py | 4 +++- kimm/models/convnext.py | 10 +++++++++- kimm/models/densenet.py | 5 ++++- kimm/models/efficientnet.py | 27 ++++++++++++++++++++++++++- kimm/models/ghostnet.py | 7 ++++++- kimm/models/inception_v3.py | 2 +- kimm/models/mobilenet_v2.py | 6 +++++- kimm/models/mobilenet_v3.py | 9 ++++++++- kimm/models/mobilevit.py | 4 +++- kimm/models/regnet.py | 25 ++++++++++++++++++++++++- kimm/models/resnet.py | 6 +++++- kimm/models/vgg.py | 5 ++++- kimm/models/vision_transformer.py | 8 +++++++- kimm/models/xception.py | 2 +- 14 files changed, 106 insertions(+), 14 deletions(-) diff --git a/kimm/models/convmixer.py b/kimm/models/convmixer.py index ce93f38..d3d455a 100644 --- a/kimm/models/convmixer.py +++ b/kimm/models/convmixer.py @@ -52,7 +52,6 @@ def __init__( activation: str = "relu", **kwargs, ): - kwargs = self.fix_config(kwargs) kwargs["weights_url"] = self.get_weights_url(kwargs["weights"]) input_tensor = kwargs.pop("input_tensor", None) @@ -158,6 +157,7 @@ def __init__( name: str = "ConvMixer736D32", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( 32, 768, @@ -202,6 +202,7 @@ def __init__( name: str = "ConvMixer1024D20", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( 20, 1024, @@ -246,6 +247,7 @@ def __init__( name: str = "ConvMixer1536D20", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( 20, 1536, diff --git a/kimm/models/convnext.py b/kimm/models/convnext.py index c890e42..f0edf67 100644 --- a/kimm/models/convnext.py +++ b/kimm/models/convnext.py @@ -135,7 +135,6 @@ def __init__( use_conv_mlp: bool = False, **kwargs, ): - kwargs = self.fix_config(kwargs) kwargs["weights_url"] = self.get_weights_url(kwargs["weights"]) input_tensor = kwargs.pop("input_tensor", None) @@ -261,6 +260,7 @@ def __init__( name: str = "ConvNeXtAtto", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( (2, 2, 6, 2), (40, 80, 160, 320), @@ -305,6 +305,7 @@ def __init__( name: str = "ConvNeXtFemto", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( (2, 2, 6, 2), (48, 96, 192, 384), @@ -349,6 +350,7 @@ def __init__( name: str = "ConvNeXtPico", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( (2, 2, 6, 2), (64, 128, 
256, 512), @@ -393,6 +395,7 @@ def __init__( name: str = "ConvNeXtNano", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( (2, 2, 8, 2), (80, 160, 320, 640), @@ -437,6 +440,7 @@ def __init__( name: str = "ConvNeXtTiny", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( (3, 3, 9, 3), (96, 192, 384, 768), @@ -481,6 +485,7 @@ def __init__( name: str = "ConvNeXtSmall", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( (3, 3, 27, 3), (96, 192, 384, 768), @@ -525,6 +530,7 @@ def __init__( name: str = "ConvNeXtBase", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( (3, 3, 27, 3), (128, 256, 512, 1024), @@ -569,6 +575,7 @@ def __init__( name: str = "ConvNeXtLarge", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( (3, 3, 27, 3), (192, 384, 768, 1536), @@ -607,6 +614,7 @@ def __init__( name: str = "ConvNeXtXLarge", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( (3, 3, 27, 3), (256, 512, 1024, 2048), diff --git a/kimm/models/densenet.py b/kimm/models/densenet.py index 2162cd7..3993cc3 100644 --- a/kimm/models/densenet.py +++ b/kimm/models/densenet.py @@ -77,7 +77,6 @@ def __init__( num_blocks: typing.Sequence[int] = [6, 12, 24, 16], **kwargs, ): - kwargs = self.fix_config(kwargs) kwargs["weights_url"] = self.get_weights_url(kwargs["weights"]) input_tensor = kwargs.pop("input_tensor", None) @@ -183,6 +182,7 @@ def __init__( name: str = "DenseNet121", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( 32, [6, 12, 24, 16], @@ -224,6 +224,7 @@ def __init__( name: str = "DenseNet161", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( 48, [6, 12, 36, 24], @@ -265,6 +266,7 @@ def __init__( name: str = "DenseNet169", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( 32, [6, 12, 32, 32], @@ -306,6 +308,7 @@ def __init__( name: str = "DenseNet201", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( 32, [6, 12, 48, 32], diff --git a/kimm/models/efficientnet.py b/kimm/models/efficientnet.py index 31c0e52..ab6f0b4 100644 --- a/kimm/models/efficientnet.py +++ b/kimm/models/efficientnet.py @@ -147,7 +147,6 @@ def __init__( config: str = "v1", **kwargs, ): - kwargs = self.fix_config(kwargs) kwargs["weights_url"] = self.get_weights_url(kwargs["weights"]) _available_configs = [ @@ -343,6 +342,7 @@ def __init__( name: str = "EfficientNetB0", **kwargs, ): + kwargs = self.fix_config(kwargs) # default to TF configuration (bn_epsilon=1e-3 and padding="same") super().__init__( 1.0, @@ -394,6 +394,7 @@ def __init__( name: str = "EfficientNetB1", **kwargs, ): + kwargs = self.fix_config(kwargs) # default to TF configuration (bn_epsilon=1e-3 and padding="same") super().__init__( 1.0, @@ -445,6 +446,7 @@ def __init__( name: str = "EfficientNetB2", **kwargs, ): + kwargs = self.fix_config(kwargs) # default to TF configuration (bn_epsilon=1e-3 and padding="same") super().__init__( 1.1, @@ -496,6 +498,7 @@ def __init__( name: str = "EfficientNetB3", **kwargs, ): + kwargs = self.fix_config(kwargs) # default to TF configuration (bn_epsilon=1e-3 and padding="same") super().__init__( 1.2, @@ -547,6 +550,7 @@ def __init__( name: str = "EfficientNetB4", **kwargs, ): + kwargs = self.fix_config(kwargs) # default to TF configuration (bn_epsilon=1e-3 and padding="same") super().__init__( 1.4, @@ -598,6 +602,7 @@ def __init__( name: str = "EfficientNetB5", **kwargs, ): + kwargs = self.fix_config(kwargs) # default to TF configuration (bn_epsilon=1e-3 and padding="same") 
super().__init__( 1.6, @@ -649,6 +654,7 @@ def __init__( name: str = "EfficientNetB6", **kwargs, ): + kwargs = self.fix_config(kwargs) # default to TF configuration (bn_epsilon=1e-3 and padding="same") super().__init__( 1.8, @@ -700,6 +706,7 @@ def __init__( name: str = "EfficientNetB7", **kwargs, ): + kwargs = self.fix_config(kwargs) # default to TF configuration (bn_epsilon=1e-3 and padding="same") super().__init__( 2.0, @@ -751,6 +758,7 @@ def __init__( name: str = "EfficientNetLiteB0", **kwargs, ): + kwargs = self.fix_config(kwargs) # default to TF configuration (bn_epsilon=1e-3 and padding="same") super().__init__( 1.0, @@ -802,6 +810,7 @@ def __init__( name: str = "EfficientNetLiteB1", **kwargs, ): + kwargs = self.fix_config(kwargs) # default to TF configuration (bn_epsilon=1e-3 and padding="same") super().__init__( 1.0, @@ -853,6 +862,7 @@ def __init__( name: str = "EfficientNetLiteB2", **kwargs, ): + kwargs = self.fix_config(kwargs) # default to TF configuration (bn_epsilon=1e-3 and padding="same") super().__init__( 1.1, @@ -904,6 +914,7 @@ def __init__( name: str = "EfficientNetLiteB3", **kwargs, ): + kwargs = self.fix_config(kwargs) # default to TF configuration (bn_epsilon=1e-3 and padding="same") super().__init__( 1.2, @@ -955,6 +966,7 @@ def __init__( name: str = "EfficientNetLiteB4", **kwargs, ): + kwargs = self.fix_config(kwargs) # default to TF configuration (bn_epsilon=1e-3 and padding="same") super().__init__( 1.4, @@ -1010,6 +1022,7 @@ def __init__( name: str = "EfficientNetV2S", **kwargs, ): + kwargs = self.fix_config(kwargs) # default to TF configuration (bn_epsilon=1e-3 and padding="same") super().__init__( 1.0, @@ -1061,6 +1074,7 @@ def __init__( name: str = "EfficientNetV2M", **kwargs, ): + kwargs = self.fix_config(kwargs) # default to TF configuration (bn_epsilon=1e-3 and padding="same") super().__init__( 1.0, @@ -1112,6 +1126,7 @@ def __init__( name: str = "EfficientNetV2L", **kwargs, ): + kwargs = self.fix_config(kwargs) # default to TF configuration (bn_epsilon=1e-3 and padding="same") super().__init__( 1.0, @@ -1163,6 +1178,7 @@ def __init__( name: str = "EfficientNetV2XL", **kwargs, ): + kwargs = self.fix_config(kwargs) # default to TF configuration (bn_epsilon=1e-3 and padding="same") super().__init__( 1.0, @@ -1218,6 +1234,7 @@ def __init__( name: str = "EfficientNetV2B0", **kwargs, ): + kwargs = self.fix_config(kwargs) # default to TF configuration (bn_epsilon=1e-3 and padding="same") super().__init__( 1.0, @@ -1273,6 +1290,7 @@ def __init__( name: str = "EfficientNetV2B1", **kwargs, ): + kwargs = self.fix_config(kwargs) # default to TF configuration (bn_epsilon=1e-3 and padding="same") super().__init__( 1.0, @@ -1328,6 +1346,7 @@ def __init__( name: str = "EfficientNetV2B2", **kwargs, ): + kwargs = self.fix_config(kwargs) # default to TF configuration (bn_epsilon=1e-3 and padding="same") super().__init__( 1.1, @@ -1384,6 +1403,7 @@ def __init__( name: str = "EfficientNetV2B3", **kwargs, ): + kwargs = self.fix_config(kwargs) # default to TF configuration (bn_epsilon=1e-3 and padding="same") super().__init__( 1.2, @@ -1436,6 +1456,7 @@ def __init__( name: str = "TinyNetA", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( 1.0, 1.2, @@ -1485,6 +1506,7 @@ def __init__( name: str = "TinyNetB", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( 0.75, 1.1, @@ -1534,6 +1556,7 @@ def __init__( name: str = "TinyNetC", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( 0.54, 0.85, @@ -1583,6 +1606,7 @@ def __init__( name: 
str = "TinyNetD", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( 0.54, 0.695, @@ -1632,6 +1656,7 @@ def __init__( name: str = "TinyNetE", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( 0.51, 0.6, diff --git a/kimm/models/ghostnet.py b/kimm/models/ghostnet.py index 01cb8a8..b932829 100644 --- a/kimm/models/ghostnet.py +++ b/kimm/models/ghostnet.py @@ -245,7 +245,6 @@ def __init__( version: typing.Literal["v1", "v2"] = "v1", **kwargs, ): - kwargs = self.fix_config(kwargs) kwargs["weights_url"] = self.get_weights_url(kwargs["weights"]) _available_configs = ["default"] @@ -384,6 +383,7 @@ def __init__( name: str = "GhostNet050", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( 0.5, config, @@ -426,6 +426,7 @@ def __init__( name: str = "GhostNet100", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( 1.0, config, @@ -462,6 +463,7 @@ def __init__( name: str = "GhostNet130", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( 1.3, config, @@ -504,6 +506,7 @@ def __init__( name: str = "GhostNet100V2", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( 1.0, config, @@ -546,6 +549,7 @@ def __init__( name: str = "GhostNet130V2", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( 1.3, config, @@ -588,6 +592,7 @@ def __init__( name: str = "GhostNet160V2", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( 1.6, config, diff --git a/kimm/models/inception_v3.py b/kimm/models/inception_v3.py index fddc500..55a3043 100644 --- a/kimm/models/inception_v3.py +++ b/kimm/models/inception_v3.py @@ -209,7 +209,6 @@ class InceptionV3Base(BaseModel): ] def __init__(self, has_aux_logits: bool = False, **kwargs): - kwargs = self.fix_config(kwargs) kwargs["weights_url"] = self.get_weights_url(kwargs["weights"]) input_tensor = kwargs.pop("input_tensor", None) @@ -310,6 +309,7 @@ def __init__( name: str = "InceptionV3", **kwargs, ): + kwargs = self.fix_config(kwargs) if weights == "imagenet": if has_aux_logits: weights = f"{weights}_aux_logits" diff --git a/kimm/models/mobilenet_v2.py b/kimm/models/mobilenet_v2.py index 2f33736..f5150f1 100644 --- a/kimm/models/mobilenet_v2.py +++ b/kimm/models/mobilenet_v2.py @@ -40,7 +40,6 @@ def __init__( config: typing.Literal["default"] = "default", **kwargs, ): - kwargs = self.fix_config(kwargs) kwargs["weights_url"] = self.get_weights_url(kwargs["weights"]) _available_configs = ["default"] @@ -175,6 +174,7 @@ def __init__( name: str = "MobileNet050V2", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( 0.5, 1.0, @@ -218,6 +218,7 @@ def __init__( name: str = "MobileNet100V2", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( 1.0, 1.0, @@ -261,6 +262,7 @@ def __init__( name: str = "MobileNet110V2", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( 1.1, 1.2, @@ -304,6 +306,7 @@ def __init__( name: str = "MobileNet120V2", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( 1.2, 1.4, @@ -347,6 +350,7 @@ def __init__( name: str = "MobileNet140V2", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( 1.4, 1.0, diff --git a/kimm/models/mobilenet_v3.py b/kimm/models/mobilenet_v3.py index af1daea..5d61396 100644 --- a/kimm/models/mobilenet_v3.py +++ b/kimm/models/mobilenet_v3.py @@ -90,7 +90,6 @@ def __init__( minimal: bool = False, **kwargs, ): - kwargs = self.fix_config(kwargs) kwargs["weights_url"] = self.get_weights_url(kwargs["weights"]) _available_configs = ["small", "large", "lcnet"] 
@@ -332,6 +331,7 @@ def __init__( name: str = "MobileNet050V3Small", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( 0.5, 1.0, @@ -379,6 +379,7 @@ def __init__( name: str = "MobileNet075V3Small", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( 0.75, 1.0, @@ -426,6 +427,7 @@ def __init__( name: str = "MobileNet100V3Small", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( 1.0, 1.0, @@ -476,6 +478,7 @@ def __init__( name: str = "MobileNet100V3SmallMinimal", **kwargs, ): + kwargs = self.fix_config(kwargs) # default to TF configuration (bn_epsilon=1e-3 and padding="same") super().__init__( 1.0, @@ -533,6 +536,7 @@ def __init__( name: str = "MobileNet100V3Large", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( 1.0, 1.0, @@ -691,6 +695,7 @@ def __init__( name: str = "LCNet050", **kwargs, ): + kwargs = self.fix_config(kwargs) # default to TF configuration (bn_epsilon=1e-3 and padding="same") super().__init__( 0.5, @@ -739,6 +744,7 @@ def __init__( name: str = "LCNet075", **kwargs, ): + kwargs = self.fix_config(kwargs) # default to TF configuration (bn_epsilon=1e-3 and padding="same") super().__init__( 0.75, @@ -787,6 +793,7 @@ def __init__( name: str = "LCNet100", **kwargs, ): + kwargs = self.fix_config(kwargs) # default to TF configuration (bn_epsilon=1e-3 and padding="same") super().__init__( 1.0, diff --git a/kimm/models/mobilevit.py b/kimm/models/mobilevit.py index 31342fb..c065d5d 100644 --- a/kimm/models/mobilevit.py +++ b/kimm/models/mobilevit.py @@ -175,7 +175,6 @@ def __init__( config: str = "v1_s", **kwargs, ): - kwargs = self.fix_config(kwargs) kwargs["weights_url"] = self.get_weights_url(kwargs["weights"]) _available_configs = ["v1_s", "v1_xs", "v1_xss"] @@ -314,6 +313,7 @@ def __init__( name="MobileViTS", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( 16, 640, @@ -357,6 +357,7 @@ def __init__( name="MobileViTXS", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( 16, 384, @@ -400,6 +401,7 @@ def __init__( name="MobileViTXXS", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( 16, 320, diff --git a/kimm/models/regnet.py b/kimm/models/regnet.py index c0263a8..6ab337e 100644 --- a/kimm/models/regnet.py +++ b/kimm/models/regnet.py @@ -158,7 +158,6 @@ def __init__( se_ratio: float = 0.0, **kwargs, ): - kwargs = self.fix_config(kwargs) kwargs["weights_url"] = self.get_weights_url(kwargs["weights"]) per_stage_config = _generate_regnet(w0, wa, wm, group_size, depth) @@ -259,6 +258,7 @@ def __init__( name: str = "RegNetX002", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( 24, 36.44, @@ -302,6 +302,7 @@ def __init__( name: str = "RegNetY002", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( 24, 36.44, @@ -346,6 +347,7 @@ def __init__( name: str = "RegNetX004", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( 24, 24.48, @@ -389,6 +391,7 @@ def __init__( name: str = "RegNetY004", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( 48, 27.89, @@ -433,6 +436,7 @@ def __init__( name: str = "RegNetX006", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( 48, 36.97, @@ -476,6 +480,7 @@ def __init__( name: str = "RegNetY006", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( 48, 32.54, @@ -520,6 +525,7 @@ def __init__( name: str = "RegNetX008", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( 56, 35.73, @@ -563,6 +569,7 @@ def __init__( name: str = "RegNetY008", **kwargs, 
): + kwargs = self.fix_config(kwargs) super().__init__( 56, 38.84, @@ -607,6 +614,7 @@ def __init__( name: str = "RegNetX016", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( 80, 34.01, @@ -650,6 +658,7 @@ def __init__( name: str = "RegNetY016", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( 48, 20.71, @@ -694,6 +703,7 @@ def __init__( name: str = "RegNetX032", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( 88, 26.31, @@ -737,6 +747,7 @@ def __init__( name: str = "RegNetY032", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( 80, 42.63, @@ -781,6 +792,7 @@ def __init__( name: str = "RegNetX040", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( 96, 38.65, @@ -824,6 +836,7 @@ def __init__( name: str = "RegNetY040", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( 96, 31.41, @@ -868,6 +881,7 @@ def __init__( name: str = "RegNetX064", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( 184, 60.83, @@ -911,6 +925,7 @@ def __init__( name: str = "RegNetY064", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( 112, 33.22, @@ -955,6 +970,7 @@ def __init__( name: str = "RegNetX080", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( 80, 49.56, @@ -998,6 +1014,7 @@ def __init__( name: str = "RegNetY080", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( 192, 76.82, @@ -1042,6 +1059,7 @@ def __init__( name: str = "RegNetX120", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( 168, 73.36, @@ -1085,6 +1103,7 @@ def __init__( name: str = "RegNetY120", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( 168, 73.36, @@ -1129,6 +1148,7 @@ def __init__( name: str = "RegNetX160", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( 216, 55.59, @@ -1172,6 +1192,7 @@ def __init__( name: str = "RegNetY160", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( 200, 106.23, @@ -1216,6 +1237,7 @@ def __init__( name: str = "RegNetX320", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( 320, 69.86, @@ -1259,6 +1281,7 @@ def __init__( name: str = "RegNetY320", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( 232, 115.89, diff --git a/kimm/models/resnet.py b/kimm/models/resnet.py index 90940b3..298dea2 100644 --- a/kimm/models/resnet.py +++ b/kimm/models/resnet.py @@ -116,7 +116,6 @@ def __init__( num_blocks: typing.Sequence[int], **kwargs, ): - kwargs = self.fix_config(kwargs) kwargs["weights_url"] = self.get_weights_url(kwargs["weights"]) if block_fn not in ("basic", "bottleneck"): @@ -222,6 +221,7 @@ def __init__( name: str = "ResNet18", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( "basic", [2, 2, 2, 2], @@ -262,6 +262,7 @@ def __init__( name: str = "ResNet34", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( "basic", [3, 4, 6, 3], @@ -302,6 +303,7 @@ def __init__( name: str = "ResNet50", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( "bottleneck", [3, 4, 6, 3], @@ -342,6 +344,7 @@ def __init__( name: str = "ResNet101", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( "bottleneck", [3, 4, 23, 3], @@ -382,6 +385,7 @@ def __init__( name: str = "ResNet152", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( "bottleneck", [3, 8, 36, 3], diff --git a/kimm/models/vgg.py b/kimm/models/vgg.py index 5ec17ea..069426e 100644 --- a/kimm/models/vgg.py +++ b/kimm/models/vgg.py @@ -113,7 
+113,6 @@ class VGG(BaseModel): ] def __init__(self, config: typing.Union[str, typing.List], **kwargs): - kwargs = self.fix_config(kwargs) kwargs["weights_url"] = self.get_weights_url(kwargs["weights"]) _available_configs = ["vgg11", "vgg13", "vgg16", "vgg19"] @@ -225,6 +224,7 @@ def __init__( name: str = "VGG11", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( "vgg11", input_tensor=input_tensor, @@ -264,6 +264,7 @@ def __init__( name: str = "VGG13", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( "vgg13", input_tensor=input_tensor, @@ -303,6 +304,7 @@ def __init__( name: str = "VGG16", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( "vgg16", input_tensor=input_tensor, @@ -342,6 +344,7 @@ def __init__( name: str = "VGG19", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( "vgg19", input_tensor=input_tensor, diff --git a/kimm/models/vision_transformer.py b/kimm/models/vision_transformer.py index 2fdf81d..82f757b 100644 --- a/kimm/models/vision_transformer.py +++ b/kimm/models/vision_transformer.py @@ -23,7 +23,6 @@ def __init__( pos_dropout_rate: float = 0.0, **kwargs, ): - kwargs = self.fix_config(kwargs) kwargs["weights_url"] = self.get_weights_url(kwargs["weights"]) input_tensor = kwargs.pop("input_tensor", None) @@ -171,6 +170,7 @@ def __init__( name: str = "VisionTransformerTiny16", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( 16, 192, @@ -274,6 +274,7 @@ def __init__( name: str = "VisionTransformerSmall16", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( 16, 384, @@ -328,6 +329,7 @@ def __init__( name: str = "VisionTransformerSmall32", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( 32, 384, @@ -382,6 +384,7 @@ def __init__( name: str = "VisionTransformerBase16", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( 16, 768, @@ -436,6 +439,7 @@ def __init__( name: str = "VisionTransformerBase32", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( 32, 768, @@ -484,6 +488,7 @@ def __init__( name: str = "VisionTransformerLarge16", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( 16, 1024, @@ -532,6 +537,7 @@ def __init__( name: str = "VisionTransformerLarge32", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( 32, 1024, diff --git a/kimm/models/xception.py b/kimm/models/xception.py index 90517f6..b9f743d 100644 --- a/kimm/models/xception.py +++ b/kimm/models/xception.py @@ -72,7 +72,6 @@ class XceptionBase(BaseModel): ] def __init__(self, **kwargs): - kwargs = self.fix_config(kwargs) kwargs["weights_url"] = self.get_weights_url(kwargs["weights"]) input_tensor = kwargs.pop("input_tensor", None) @@ -176,6 +175,7 @@ def __init__( name: str = "Xception", **kwargs, ): + kwargs = self.fix_config(kwargs) super().__init__( input_tensor=input_tensor, input_shape=input_shape, From 588956ec9b11e85a87ab066b85b36567792cdbea Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Sun, 21 Jan 2024 20:47:59 +0800 Subject: [PATCH 3/4] Update version --- kimm/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kimm/__init__.py b/kimm/__init__.py index 74309d2..57a48e1 100644 --- a/kimm/__init__.py +++ b/kimm/__init__.py @@ -1,4 +1,4 @@ from kimm import models # force to add models to the registry from kimm.utils.model_registry import list_models -__version__ = "0.1.2" +__version__ = "0.1.3" From ac42239efa11d4a706fe64c3b7620314f75df37d Mon Sep 17 
00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Sun, 21 Jan 2024 20:57:48 +0800 Subject: [PATCH 4/4] Update `README` --- README.md | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index adbc733..d74b769 100644 --- a/README.md +++ b/README.md @@ -29,9 +29,9 @@ pip install keras kimm [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/14WxYgVjlwCIO9MwqPYW-dskbTL2UHsVN?usp=sharing) ```python -import cv2 import keras from keras import ops +from keras import utils from keras.applications.imagenet_utils import decode_predictions import kimm @@ -43,15 +43,15 @@ print(kimm.list_models()) print(kimm.list_models("efficientnet", weights="imagenet")) # fuzzy search # Initialize the model with pretrained weights -model = kimm.models.EfficientNetV2B0() -image_size = model._default_size +model = kimm.models.VisionTransformerTiny16() +image_size = (model._default_size, model._default_size) # Load an image as the model input image_path = keras.utils.get_file( "african_elephant.jpg", "https://i.imgur.com/Bvro0YD.png" ) -image = cv2.imread(image_path) -image = cv2.resize(image, (image_size, image_size)) +image = utils.load_img(image_path, target_size=image_size) +image = utils.img_to_array(image) x = ops.convert_to_tensor(image) x = ops.expand_dims(x, axis=0) @@ -62,9 +62,9 @@ print("Predicted:", decode_predictions(preds, top=3)[0]) ```bash ['ConvMixer1024D20', 'ConvMixer1536D20', 'ConvMixer736D32', 'ConvNeXtAtto', ...] -['EfficientNetB0', 'EfficientNetB1', 'EfficientNetB2', 'EfficientNetB3', ...] -1/1 ━━━━━━━━━━━━━━━━━━━━ 11s 11s/step -Predicted: [('n02504458', 'African_elephant', 0.90578836), ('n01871265', 'tusker', 0.024864597), ('n02504013', 'Indian_elephant', 0.01161992)] +['VisionTransformerBase16', 'VisionTransformerBase32', 'VisionTransformerSmall16', ...] +1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 1s/step +Predicted: [('n02504458', 'African_elephant', 0.6895825), ('n01871265', 'tusker', 0.17934209), ('n02504013', 'Indian_elephant', 0.12927249)] ``` ### An end-to-end example: fine-tuning an image classification model on a cats vs. dogs dataset
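A minimal sketch of the recipe, assuming the `kimm` model constructors accept the Keras-applications-style `include_top=False` and `pooling="avg"` arguments seen elsewhere in the API (the `EfficientNetLiteB0` backbone, input size, and two-class head below are illustrative choices for the cats vs. dogs setup, not the repository's canonical example):

```python
import keras
import kimm

# Sketch only: `EfficientNetLiteB0` is one of the registered kimm models;
# `include_top=False` / `pooling="avg"` are assumed Keras-applications-style
# arguments, and the binary softmax head matches a cats-vs-dogs setup.
backbone = kimm.models.EfficientNetLiteB0(
    input_shape=[224, 224, 3],
    include_top=False,
    pooling="avg",
)
backbone.trainable = False  # stage 1: freeze the backbone, train only the head

model = keras.Sequential(
    [
        backbone,
        keras.layers.Dropout(0.2),
        keras.layers.Dense(2, activation="softmax"),
    ]
)
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)
# model.fit(train_ds, validation_data=val_ds, epochs=5)  # dataset pipeline omitted
```

Once the head converges, the backbone can be unfrozen (`backbone.trainable = True`) and the whole model trained end-to-end at a lower learning rate.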