Add docstrings for kimm.blocks.* #50

Merged: 2 commits, Jun 2, 2024
21 changes: 13 additions & 8 deletions kimm/_src/blocks/conv2d.py
@@ -2,6 +2,7 @@

 from keras import backend
 from keras import layers
+from keras.src.utils.argument_validation import standardize_tuple

 from kimm._src.kimm_export import kimm_export

@@ -10,29 +11,33 @@
 def apply_conv2d_block(
     inputs,
     filters: typing.Optional[int] = None,
-    kernel_size: typing.Optional[
-        typing.Union[int, typing.Sequence[int]]
-    ] = None,
+    kernel_size: typing.Union[int, typing.Sequence[int]] = 1,
     strides: int = 1,
     groups: int = 1,
     activation: typing.Optional[str] = None,
     use_depthwise: bool = False,
-    add_skip: bool = False,
+    has_skip: bool = False,
     bn_momentum: float = 0.9,
     bn_epsilon: float = 1e-5,
     padding: typing.Optional[typing.Literal["same", "valid"]] = None,
     name="conv2d_block",
 ):
+    """(ZeroPadding) + Conv2D/DepthwiseConv2D + BN + (Activation)."""
     if kernel_size is None:
         raise ValueError(
             f"kernel_size must be passed. Received: kernel_size={kernel_size}"
         )
-    if isinstance(kernel_size, int):
-        kernel_size = [kernel_size, kernel_size]
+    kernel_size = standardize_tuple(kernel_size, 2, "kernel_size")

     channels_axis = -1 if backend.image_data_format() == "channels_last" else -3
-    input_channels = inputs.shape[channels_axis]
-    has_skip = add_skip and strides == 1 and input_channels == filters
+    input_filters = inputs.shape[channels_axis]
+    if has_skip and (strides != 1 or input_filters != filters):
+        raise ValueError(
+            "If `has_skip=True`, strides must be 1 and `filters` must be the "
+            "same as input_filters. "
+            f"Received: strides={strides}, filters={filters}, "
+            f"input_filters={input_filters}"
+        )
     x = inputs

     if padding is None:
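
The net effect of the conv2d.py change: `add_skip=True` used to decide skip eligibility on the caller's behalf (silently dropping the residual when strides or channel counts ruled it out), while `has_skip=True` is now an explicit contract that raises on an invalid configuration. A minimal sketch of the new calling pattern, with made-up shapes and assuming channels_last:

import keras

from kimm._src.blocks.conv2d import apply_conv2d_block

inputs = keras.layers.Input(shape=(32, 32, 64))
filters, strides = 64, 1

# The caller now decides skip eligibility explicitly.
has_skip = inputs.shape[-1] == filters and strides == 1
x = apply_conv2d_block(inputs, filters, 3, strides, has_skip=has_skip)

# This would now raise a ValueError (strides != 1 rules out the residual)
# instead of silently dropping the skip connection:
# apply_conv2d_block(inputs, filters, 3, 2, has_skip=True)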
17 changes: 12 additions & 5 deletions kimm/_src/blocks/depthwise_separation.py
@@ -11,7 +11,7 @@

 @kimm_export(parent_path=["kimm.blocks"])
 def apply_depthwise_separation_block(
     inputs,
-    output_channels: int,
+    filters: int,
     depthwise_kernel_size: int = 3,
     pointwise_kernel_size: int = 1,
     strides: int = 1,
@@ -21,14 +21,21 @@ def apply_depthwise_separation_block(
     se_gate_activation: typing.Optional[str] = "sigmoid",
     se_make_divisible_number: typing.Optional[int] = None,
     pw_activation: typing.Optional[str] = None,
-    skip: bool = True,
+    has_skip: bool = True,
     bn_epsilon: float = 1e-5,
     padding: typing.Optional[typing.Literal["same", "valid"]] = None,
     name: str = "depthwise_separation_block",
 ):
+    """Conv2D block + (SqueezeAndExcitation) + Conv2D."""
     channels_axis = -1 if backend.image_data_format() == "channels_last" else -3
-    input_channels = inputs.shape[channels_axis]
-    has_skip = skip and (strides == 1 and input_channels == output_channels)
+    input_filters = inputs.shape[channels_axis]
+    if has_skip and (strides != 1 or input_filters != filters):
+        raise ValueError(
+            "If `has_skip=True`, strides must be 1 and `filters` must be the "
+            "same as input_filters. "
+            f"Received: strides={strides}, filters={filters}, "
+            f"input_filters={input_filters}"
+        )

     x = inputs
     x = apply_conv2d_block(
@@ -52,7 +59,7 @@
     )
     x = apply_conv2d_block(
         x,
-        output_channels,
+        filters,
         pointwise_kernel_size,
         1,
         activation=pw_activation,
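
Call sites migrate from `output_channels=...` and `skip=...` to `filters` and an explicitly computed `has_skip`; the model changes below all follow this pattern. A hedged usage sketch with a made-up input:

import keras
from keras import backend

from kimm._src.blocks.depthwise_separation import (
    apply_depthwise_separation_block,
)

x = keras.layers.Input(shape=(56, 56, 32))  # assumes channels_last
channels_axis = -1 if backend.image_data_format() == "channels_last" else -3
filters, strides = 32, 1

# Residual only when the block preserves shape; the caller decides.
has_skip = x.shape[channels_axis] == filters and strides == 1
x = apply_depthwise_separation_block(x, filters, 3, 1, strides, has_skip=has_skip)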
7 changes: 4 additions & 3 deletions kimm/_src/blocks/inverted_residual.py
@@ -12,7 +12,7 @@

 @kimm_export(parent_path=["kimm.blocks"])
 def apply_inverted_residual_block(
     inputs,
-    output_channels: int,
+    filters: int,
     depthwise_kernel_size: int = 3,
     expansion_kernel_size: int = 1,
     pointwise_kernel_size: int = 1,
@@ -28,10 +28,11 @@ def apply_inverted_residual_block(
     padding: typing.Optional[typing.Literal["same", "valid"]] = None,
     name: str = "inverted_residual_block",
 ):
+    """Conv2D block + DepthwiseConv2D block + (SE) + Conv2D."""
     channels_axis = -1 if backend.image_data_format() == "channels_last" else -3
     input_channels = inputs.shape[channels_axis]
     hidden_channels = make_divisible(input_channels * expansion_ratio)
-    has_skip = strides == 1 and input_channels == output_channels
+    has_skip = strides == 1 and input_channels == filters

     x = inputs
     # Point-wise expansion
@@ -70,7 +71,7 @@
     # Point-wise linear projection
     x = apply_conv2d_block(
         x,
-        output_channels,
+        filters,
         pointwise_kernel_size,
         1,
         activation=None,
1 change: 1 addition & 0 deletions kimm/_src/blocks/squeeze_and_excitation.py
@@ -17,6 +17,7 @@ def apply_se_block(
     se_input_channels: typing.Optional[int] = None,
     name: str = "se_block",
 ):
+    """Squeeze and Excitation."""
     channels_axis = -1 if backend.image_data_format() == "channels_last" else -3
     input_channels = inputs.shape[channels_axis]
     if se_input_channels is None:
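
For context, the new one-line docstring names the standard squeeze-and-excitation recipe: global average pooling (squeeze), a channel-reduction convolution, and a sigmoid gate that rescales the input (excite). A schematic sketch of that idea, not kimm's exact implementation:

from keras import layers

def se_sketch(inputs, se_ratio=0.25):
    # Assumes channels_last; kimm's block also handles channels_first.
    channels = inputs.shape[-1]
    x = layers.GlobalAveragePooling2D(keepdims=True)(inputs)  # squeeze
    x = layers.Conv2D(int(channels * se_ratio), 1, activation="relu")(x)
    x = layers.Conv2D(channels, 1, activation="sigmoid")(x)   # gate
    return layers.Multiply()([inputs, x])                     # excite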
2 changes: 2 additions & 0 deletions kimm/_src/blocks/transformer.py
@@ -19,6 +19,7 @@ def apply_mlp_block(
     data_format: typing.Optional[str] = None,
     name: str = "mlp_block",
 ):
+    """Dense/Conv2D + Activation + Dense/Conv2D."""
     if data_format is None:
         data_format = backend.image_data_format()
     dim_axis = -1 if data_format == "channels_last" else 1
@@ -56,6 +57,7 @@ def apply_transformer_block(
     activation: str = "gelu",
     name: str = "transformer_block",
 ):
+    """LN + Attention + LN + MLP block."""
     # data_format must be "channels_last"
     x = inputs
     residual_1 = x
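
The two docstrings describe a standard pre-norm layout: apply_mlp_block is Dense/Conv2D + Activation + Dense/Conv2D, and apply_transformer_block wraps attention and that MLP in LayerNorm-plus-residual pairs (the residual_1 assignment in the context above). A schematic of that structure, not kimm's exact code:

from keras import layers

def transformer_block_sketch(x, dim, num_heads, mlp_ratio=4.0):
    # LN + Attention, with residual
    residual = x
    x = layers.LayerNormalization(epsilon=1e-6)(x)
    x = layers.MultiHeadAttention(num_heads, dim // num_heads)(x, x)
    x = layers.Add()([x, residual])
    # LN + MLP block (Dense + Activation + Dense), with residual
    residual = x
    x = layers.LayerNormalization(epsilon=1e-6)(x)
    x = layers.Dense(int(dim * mlp_ratio), activation="gelu")(x)
    x = layers.Dense(dim)(x)
    return layers.Add()([x, residual])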
2 changes: 1 addition & 1 deletion kimm/_src/layers/reparameterizable_conv2d.py
@@ -19,7 +19,7 @@ def __init__(
         self,
         filters,
         kernel_size,
-        strides=(1, 1),
+        strides=1,
         padding=None,
         has_skip: bool = True,
         has_scale: bool = True,
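
Switching the default from `strides=(1, 1)` to `strides=1` should be behavior-preserving, since Keras normalizes scalar conv arguments to length-2 tuples; the standardize_tuple helper now used in conv2d.py above is that mechanism. For illustration:

from keras.src.utils.argument_validation import standardize_tuple

print(standardize_tuple(1, 2, "strides"))       # (1, 1)
print(standardize_tuple((2, 2), 2, "strides"))  # (2, 2)
print(standardize_tuple(3, 2, "kernel_size"))   # (3, 3)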
16 changes: 14 additions & 2 deletions kimm/_src/models/efficientnet.py
@@ -245,16 +245,28 @@
                 "activation": activation,
             }
             if block_type == "ds":
+                has_skip = x.shape[channels_axis] == c and s == 1
                 x = apply_depthwise_separation_block(
-                    x, c, k, 1, s, se, se_activation=activation, **_kwargs
+                    x,
+                    c,
+                    k,
+                    1,
+                    s,
+                    se,
+                    se_activation=activation,
+                    has_skip=has_skip,
+                    **_kwargs,
                 )
             elif block_type == "ir":
                 se_c = x.shape[channels_axis]
                 x = apply_inverted_residual_block(
                     x, c, k, 1, 1, s, e, se, se_channels=se_c, **_kwargs
                 )
             elif block_type == "cn":
-                x = apply_conv2d_block(x, c, k, s, add_skip=True, **_kwargs)
+                has_skip = x.shape[channels_axis] == c and s == 1
+                x = apply_conv2d_block(
+                    x, c, k, s, has_skip=has_skip, **_kwargs
+                )
             elif block_type == "er":
                 x = apply_edge_residual_block(x, c, k, 1, s, e, **_kwargs)
             current_stride *= s
6 changes: 3 additions & 3 deletions kimm/_src/models/hgnet.py
@@ -267,7 +267,7 @@
     hidden_channels,
     output_channels,
     kernel_size,
-    add_skip=False,
+    has_skip=False,
     use_light_block=False,
     use_learnable_affine=False,
     aggregation="ese",
@@ -329,7 +329,7 @@
         name=f"{name}_aggregation_0",
     )
     x = apply_ese_module(x, output_channels, name=f"{name}_aggregation_1")
-    if add_skip:
+    if has_skip:
         x = layers.Add()([x, inputs])
     return x

@@ -375,7 +375,7 @@
     hidden_channels,
     output_channels,
     kernel_size,
-    add_skip=False if i == 0 else True,
+    has_skip=False if i == 0 else True,
     use_light_block=use_light_block,
     use_learnable_affine=use_learnable_affine,
     aggregation=aggregation,
15 changes: 14 additions & 1 deletion kimm/_src/models/mobilenet_v2.py
@@ -3,6 +3,7 @@

 import typing

 import keras
+from keras import backend

 from kimm._src.blocks.conv2d import apply_conv2d_block
 from kimm._src.blocks.depthwise_separation import (
@@ -55,6 +56,10 @@ def __init__(
         )

         self.set_properties(kwargs)
+        channels_axis = (
+            -1 if backend.image_data_format() == "channels_last" else -3
+        )
+
         inputs = self.determine_input_tensor(
             input_tensor,
             self._input_shape,
@@ -93,8 +98,16 @@ def __init__(
             s = s if current_layer_idx == 0 else 1
             name = f"blocks_{current_block_idx}_{current_layer_idx}"
             if block_type == "ds":
+                has_skip = x.shape[channels_axis] == c and s == 1
                 x = apply_depthwise_separation_block(
-                    x, c, k, 1, s, activation="relu6", name=name
+                    x,
+                    c,
+                    k,
+                    1,
+                    s,
+                    activation="relu6",
+                    has_skip=has_skip,
+                    name=name,
                 )
             elif block_type == "ir":
                 x = apply_inverted_residual_block(
11 changes: 10 additions & 1 deletion kimm/_src/models/mobilenet_v3.py
@@ -4,6 +4,7 @@

 import warnings

 import keras
+from keras import backend
 from keras import layers

 from kimm._src.blocks.conv2d import apply_conv2d_block
@@ -124,6 +125,10 @@
         padding = kwargs.pop("padding", None)

         self.set_properties(kwargs)
+        channels_axis = (
+            -1 if backend.image_data_format() == "channels_last" else -3
+        )
+
         inputs = self.determine_input_tensor(
             input_tensor,
             self._input_shape,
@@ -181,6 +186,10 @@
                 ),
             }
             if block_type in ("ds", "dsa"):
+                if block_type == "dsa":
+                    has_skip = False
+                else:
+                    has_skip = x.shape[channels_axis] == c and s == 1
                 x = apply_depthwise_separation_block(
                     x,
                     c,
@@ -193,7 +202,7 @@
                     se_gate_activation="hard_sigmoid",
                     se_make_divisible_number=8,
                     pw_activation=act if block_type == "dsa" else None,
-                    skip=False if block_type == "dsa" else True,
+                    has_skip=has_skip,
                     **_kwargs,
                 )
             elif block_type == "ir":