diff --git a/kimm/_src/blocks/conv2d.py b/kimm/_src/blocks/conv2d.py
index 62c3d19..7f7ec37 100644
--- a/kimm/_src/blocks/conv2d.py
+++ b/kimm/_src/blocks/conv2d.py
@@ -2,6 +2,7 @@
 
 from keras import backend
 from keras import layers
+from keras.src.utils.argument_validation import standardize_tuple
 
 from kimm._src.kimm_export import kimm_export
 
@@ -10,29 +11,33 @@ def apply_conv2d_block(
     inputs,
     filters: typing.Optional[int] = None,
-    kernel_size: typing.Optional[
-        typing.Union[int, typing.Sequence[int]]
-    ] = None,
+    kernel_size: typing.Union[int, typing.Sequence[int]] = 1,
     strides: int = 1,
     groups: int = 1,
     activation: typing.Optional[str] = None,
     use_depthwise: bool = False,
-    add_skip: bool = False,
+    has_skip: bool = False,
     bn_momentum: float = 0.9,
     bn_epsilon: float = 1e-5,
     padding: typing.Optional[typing.Literal["same", "valid"]] = None,
     name="conv2d_block",
 ):
+    """(ZeroPadding) + Conv2D/DepthwiseConv2D + BN + (Activation)."""
     if kernel_size is None:
         raise ValueError(
             f"kernel_size must be passed. Received: kernel_size={kernel_size}"
         )
 
-    if isinstance(kernel_size, int):
-        kernel_size = [kernel_size, kernel_size]
+    kernel_size = standardize_tuple(kernel_size, 2, "kernel_size")
 
     channels_axis = -1 if backend.image_data_format() == "channels_last" else -3
-    input_channels = inputs.shape[channels_axis]
-    has_skip = add_skip and strides == 1 and input_channels == filters
+    input_filters = inputs.shape[channels_axis]
+    if has_skip and (strides != 1 or input_filters != filters):
+        raise ValueError(
+            "If `has_skip=True`, strides must be 1 and `filters` must be the "
+            "same as input_filters. "
+            f"Received: strides={strides}, filters={filters}, "
+            f"input_filters={input_filters}"
+        )
     x = inputs
 
     if padding is None:
diff --git a/kimm/_src/blocks/depthwise_separation.py b/kimm/_src/blocks/depthwise_separation.py
index fbbefb5..9ce2042 100644
--- a/kimm/_src/blocks/depthwise_separation.py
+++ b/kimm/_src/blocks/depthwise_separation.py
@@ -11,7 +11,7 @@
 @kimm_export(parent_path=["kimm.blocks"])
 def apply_depthwise_separation_block(
     inputs,
-    output_channels: int,
+    filters: int,
     depthwise_kernel_size: int = 3,
     pointwise_kernel_size: int = 1,
     strides: int = 1,
@@ -21,14 +21,21 @@
     se_gate_activation: typing.Optional[str] = "sigmoid",
     se_make_divisible_number: typing.Optional[int] = None,
     pw_activation: typing.Optional[str] = None,
-    skip: bool = True,
+    has_skip: bool = True,
     bn_epsilon: float = 1e-5,
     padding: typing.Optional[typing.Literal["same", "valid"]] = None,
     name: str = "depthwise_separation_block",
 ):
+    """Conv2D block + (SqueezeAndExcitation) + Conv2D."""
     channels_axis = -1 if backend.image_data_format() == "channels_last" else -3
-    input_channels = inputs.shape[channels_axis]
-    has_skip = skip and (strides == 1 and input_channels == output_channels)
+    input_filters = inputs.shape[channels_axis]
+    if has_skip and (strides != 1 or input_filters != filters):
+        raise ValueError(
+            "If `has_skip=True`, strides must be 1 and `filters` must be the "
+            "same as input_filters. "
+            f"Received: strides={strides}, filters={filters}, "
+            f"input_filters={input_filters}"
+        )
     x = inputs
 
     x = apply_conv2d_block(
@@ -52,7 +59,7 @@
     )
     x = apply_conv2d_block(
         x,
-        output_channels,
+        filters,
         pointwise_kernel_size,
         1,
         activation=pw_activation,
diff --git a/kimm/_src/blocks/inverted_residual.py b/kimm/_src/blocks/inverted_residual.py
index 46cbe72..a89e47e 100644
--- a/kimm/_src/blocks/inverted_residual.py
+++ b/kimm/_src/blocks/inverted_residual.py
@@ -12,7 +12,7 @@
 @kimm_export(parent_path=["kimm.blocks"])
 def apply_inverted_residual_block(
     inputs,
-    output_channels: int,
+    filters: int,
     depthwise_kernel_size: int = 3,
     expansion_kernel_size: int = 1,
     pointwise_kernel_size: int = 1,
@@ -28,10 +28,11 @@
     padding: typing.Optional[typing.Literal["same", "valid"]] = None,
     name: str = "inverted_residual_block",
 ):
+    """Conv2D block + DepthwiseConv2D block + (SE) + Conv2D."""
     channels_axis = -1 if backend.image_data_format() == "channels_last" else -3
     input_channels = inputs.shape[channels_axis]
     hidden_channels = make_divisible(input_channels * expansion_ratio)
-    has_skip = strides == 1 and input_channels == output_channels
+    has_skip = strides == 1 and input_channels == filters
     x = inputs
 
     # Point-wise expansion
@@ -70,7 +71,7 @@
     # Point-wise linear projection
     x = apply_conv2d_block(
         x,
-        output_channels,
+        filters,
         pointwise_kernel_size,
         1,
         activation=None,
diff --git a/kimm/_src/blocks/squeeze_and_excitation.py b/kimm/_src/blocks/squeeze_and_excitation.py
index 8a1cef0..9c0a204 100644
--- a/kimm/_src/blocks/squeeze_and_excitation.py
+++ b/kimm/_src/blocks/squeeze_and_excitation.py
@@ -17,6 +17,7 @@ def apply_se_block(
     se_input_channels: typing.Optional[int] = None,
     name: str = "se_block",
 ):
+    """Squeeze and Excitation."""
     channels_axis = -1 if backend.image_data_format() == "channels_last" else -3
     input_channels = inputs.shape[channels_axis]
     if se_input_channels is None:
diff --git a/kimm/_src/blocks/transformer.py b/kimm/_src/blocks/transformer.py
index f984bff..1f6172a 100644
--- a/kimm/_src/blocks/transformer.py
+++ b/kimm/_src/blocks/transformer.py
@@ -19,6 +19,7 @@ def apply_mlp_block(
     data_format: typing.Optional[str] = None,
     name: str = "mlp_block",
 ):
+    """Dense/Conv2D + Activation + Dense/Conv2D."""
     if data_format is None:
         data_format = backend.image_data_format()
     dim_axis = -1 if data_format == "channels_last" else 1
@@ -56,6 +57,7 @@ def apply_transformer_block(
     activation: str = "gelu",
     name: str = "transformer_block",
 ):
+    """LN + Attention + LN + MLP block."""
     # data_format must be "channels_last"
     x = inputs
     residual_1 = x
diff --git a/kimm/_src/layers/reparameterizable_conv2d.py b/kimm/_src/layers/reparameterizable_conv2d.py
index 4145d8e..8b64fc9 100644
--- a/kimm/_src/layers/reparameterizable_conv2d.py
+++ b/kimm/_src/layers/reparameterizable_conv2d.py
@@ -19,7 +19,7 @@ def __init__(
         self,
         filters,
         kernel_size,
-        strides=(1, 1),
+        strides=1,
         padding=None,
         has_skip: bool = True,
         has_scale: bool = True,
diff --git a/kimm/_src/models/efficientnet.py b/kimm/_src/models/efficientnet.py
index 2a43ba0..43bafa0 100644
--- a/kimm/_src/models/efficientnet.py
+++ b/kimm/_src/models/efficientnet.py
@@ -245,8 +245,17 @@ def __init__(
                     "activation": activation,
                 }
                 if block_type == "ds":
+                    has_skip = x.shape[channels_axis] == c and s == 1
                     x = apply_depthwise_separation_block(
-                        x, c, k, 1, s, se, se_activation=activation, **_kwargs
+                        x,
+                        c,
+                        k,
+                        1,
+                        s,
+                        se,
+                        se_activation=activation,
+                        has_skip=has_skip,
+                        **_kwargs,
                     )
                 elif block_type == "ir":
                     se_c = x.shape[channels_axis]
@@ -254,7 +263,10 @@ x = apply_inverted_residual_block(
                         x, c, k, 1, 1, s, e, se, se_channels=se_c, **_kwargs
                     )
                 elif block_type == "cn":
-                    x = apply_conv2d_block(x, c, k, s, add_skip=True, **_kwargs)
+                    has_skip = x.shape[channels_axis] == c and s == 1
+                    x = apply_conv2d_block(
+                        x, c, k, s, has_skip=has_skip, **_kwargs
+                    )
                 elif block_type == "er":
                     x = apply_edge_residual_block(x, c, k, 1, s, e, **_kwargs)
                 current_stride *= s
diff --git a/kimm/_src/models/hgnet.py b/kimm/_src/models/hgnet.py
index 9111075..00e3242 100644
--- a/kimm/_src/models/hgnet.py
+++ b/kimm/_src/models/hgnet.py
@@ -267,7 +267,7 @@ def apply_high_perf_gpu_block(
     hidden_channels,
     output_channels,
     kernel_size,
-    add_skip=False,
+    has_skip=False,
     use_light_block=False,
     use_learnable_affine=False,
     aggregation="ese",
@@ -329,7 +329,7 @@
         name=f"{name}_aggregation_0",
     )
     x = apply_ese_module(x, output_channels, name=f"{name}_aggregation_1")
-    if add_skip:
+    if has_skip:
         x = layers.Add()([x, inputs])
     return x
 
@@ -375,7 +375,7 @@
             hidden_channels,
             output_channels,
             kernel_size,
-            add_skip=False if i == 0 else True,
+            has_skip=False if i == 0 else True,
             use_light_block=use_light_block,
             use_learnable_affine=use_learnable_affine,
             aggregation=aggregation,
diff --git a/kimm/_src/models/mobilenet_v2.py b/kimm/_src/models/mobilenet_v2.py
index 0940911..4ceac86 100644
--- a/kimm/_src/models/mobilenet_v2.py
+++ b/kimm/_src/models/mobilenet_v2.py
@@ -3,6 +3,7 @@
 import typing
 
 import keras
+from keras import backend
 
 from kimm._src.blocks.conv2d import apply_conv2d_block
 from kimm._src.blocks.depthwise_separation import (
@@ -55,6 +56,10 @@ def __init__(
         )
         self.set_properties(kwargs)
 
+        channels_axis = (
+            -1 if backend.image_data_format() == "channels_last" else -3
+        )
+
         inputs = self.determine_input_tensor(
             input_tensor,
             self._input_shape,
@@ -93,8 +98,16 @@ def __init__(
                 s = s if current_layer_idx == 0 else 1
                 name = f"blocks_{current_block_idx}_{current_layer_idx}"
                 if block_type == "ds":
+                    has_skip = x.shape[channels_axis] == c and s == 1
                     x = apply_depthwise_separation_block(
-                        x, c, k, 1, s, activation="relu6", name=name
+                        x,
+                        c,
+                        k,
+                        1,
+                        s,
+                        activation="relu6",
+                        has_skip=has_skip,
+                        name=name,
                     )
                 elif block_type == "ir":
                     x = apply_inverted_residual_block(
diff --git a/kimm/_src/models/mobilenet_v3.py b/kimm/_src/models/mobilenet_v3.py
index 976998e..b5970da 100644
--- a/kimm/_src/models/mobilenet_v3.py
+++ b/kimm/_src/models/mobilenet_v3.py
@@ -4,6 +4,7 @@
 import warnings
 
 import keras
+from keras import backend
 from keras import layers
 
 from kimm._src.blocks.conv2d import apply_conv2d_block
@@ -124,6 +125,10 @@ def __init__(
         padding = kwargs.pop("padding", None)
         self.set_properties(kwargs)
 
+        channels_axis = (
+            -1 if backend.image_data_format() == "channels_last" else -3
+        )
+
         inputs = self.determine_input_tensor(
             input_tensor,
             self._input_shape,
@@ -181,6 +186,10 @@ def __init__(
                 ),
             }
             if block_type in ("ds", "dsa"):
+                if block_type == "dsa":
+                    has_skip = False
+                else:
+                    has_skip = x.shape[channels_axis] == c and s == 1
                 x = apply_depthwise_separation_block(
                     x,
                     c,
@@ -193,7 +202,7 @@ def __init__(
                     se_gate_activation="hard_sigmoid",
                     se_make_divisible_number=8,
                     pw_activation=act if block_type == "dsa" else None,
-                    skip=False if block_type == "dsa" else True,
+                    has_skip=has_skip,
                     **_kwargs,
                 )
             elif block_type == "ir":
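
Note for reviewers: `add_skip`/`skip` previously degraded to a silent no-op when the shapes made a residual impossible; `has_skip` makes that decision the caller's job and turns an invalid request into an error. A minimal caller-side sketch of the new contract (the input shape and argument values are made up for illustration, not taken from the kimm test suite):

```python
import keras
from keras import backend

from kimm._src.blocks.conv2d import apply_conv2d_block

inputs = keras.layers.Input(shape=(32, 32, 24))
channels_axis = -1 if backend.image_data_format() == "channels_last" else -3

# The caller decides whether the residual connection is valid and passes
# the result explicitly, mirroring what the models in this diff now do.
filters, strides = 24, 1
has_skip = inputs.shape[channels_axis] == filters and strides == 1

x = apply_conv2d_block(
    inputs,
    filters=filters,
    kernel_size=3,
    strides=strides,
    activation="relu",
    has_skip=has_skip,
    name="example_conv2d_block",
)

# has_skip=True with strides=2 (or mismatched filters) now raises the new
# ValueError instead of silently dropping the skip connection.
```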
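Similarly for `kernel_size`: the manual `int -> [k, k]` promotion is replaced by Keras 3's `standardize_tuple`, which also validates sequence length. A quick sketch of the normalization this buys (this exercises the existing `keras.src` utility the diff imports, not new code in this PR):

```python
from keras.src.utils.argument_validation import standardize_tuple

# An int is broadcast to a 2-tuple; a correctly sized sequence passes through.
assert standardize_tuple(3, 2, "kernel_size") == (3, 3)
assert standardize_tuple((3, 5), 2, "kernel_size") == (3, 5)

# A sequence of the wrong length is now rejected up front.
try:
    standardize_tuple((3,), 2, "kernel_size")
except ValueError as e:
    print(e)
```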