Skip to content

Commit

Permalink
Add export_tflite and export_onnx and support channels_first
Browse files Browse the repository at this point in the history
  • Loading branch information
james77777778 committed Jan 22, 2024
1 parent 01007b3 commit b8deb81
Show file tree
Hide file tree
Showing 28 changed files with 586 additions and 92 deletions.
6 changes: 5 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -161,4 +161,8 @@ cython_debug/

# Keras
*.keras
exported
exported

# Exported model
*.tflite
*.onnx
1 change: 1 addition & 0 deletions kimm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from kimm import export
from kimm import models # force to add models to the registry
from kimm.utils.model_registry import list_models

Expand Down
19 changes: 15 additions & 4 deletions kimm/blocks/base_block.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import typing

from keras import backend
from keras import layers

from kimm.utils import make_divisible
Expand Down Expand Up @@ -43,7 +44,9 @@ def apply_conv2d_block(
)
if isinstance(kernel_size, int):
kernel_size = [kernel_size, kernel_size]
input_channels = inputs.shape[-1]

channels_axis = -1 if backend.image_data_format() == "channels_last" else -3
input_channels = inputs.shape[channels_axis]
has_skip = add_skip and strides == 1 and input_channels == filters
x = inputs

Expand Down Expand Up @@ -74,7 +77,10 @@ def apply_conv2d_block(
name=f"{name}_dwconv2d",
)(x)
x = layers.BatchNormalization(
name=f"{name}_bn", momentum=bn_momentum, epsilon=bn_epsilon
axis=channels_axis,
name=f"{name}_bn",
momentum=bn_momentum,
epsilon=bn_epsilon,
)(x)
x = apply_activation(x, activation, name=name)
if has_skip:
Expand All @@ -91,7 +97,8 @@ def apply_se_block(
se_input_channels: typing.Optional[int] = None,
name: str = "se_block",
):
input_channels = inputs.shape[-1]
channels_axis = -1 if backend.image_data_format() == "channels_last" else -3
input_channels = inputs.shape[channels_axis]
if se_input_channels is None:
se_input_channels = input_channels
if make_divisible_number is None:
Expand All @@ -102,7 +109,11 @@ def apply_se_block(
)

x = inputs
x = layers.GlobalAveragePooling2D(keepdims=True, name=f"{name}_mean")(x)
x = layers.GlobalAveragePooling2D(
data_format=backend.image_data_format(),
keepdims=True,
name=f"{name}_mean",
)(x)
x = layers.Conv2D(
se_channels, 1, use_bias=True, name=f"{name}_conv_reduce"
)(x)
Expand Down
4 changes: 3 additions & 1 deletion kimm/blocks/depthwise_separation_block.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import typing

from keras import backend
from keras import layers

from kimm.blocks.base_block import apply_conv2d_block
Expand All @@ -23,7 +24,8 @@ def apply_depthwise_separation_block(
padding: typing.Optional[typing.Literal["same", "valid"]] = None,
name: str = "depthwise_separation_block",
):
input_channels = inputs.shape[-1]
channels_axis = -1 if backend.image_data_format() == "channels_last" else -3
input_channels = inputs.shape[channels_axis]
has_skip = skip and (strides == 1 and input_channels == output_channels)

x = inputs
Expand Down
4 changes: 3 additions & 1 deletion kimm/blocks/inverted_residual_block.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import typing

from keras import backend
from keras import layers

from kimm.blocks.base_block import apply_conv2d_block
Expand All @@ -25,7 +26,8 @@ def apply_inverted_residual_block(
padding: typing.Optional[typing.Literal["same", "valid"]] = None,
name: str = "inverted_residual_block",
):
input_channels = inputs.shape[-1]
channels_axis = -1 if backend.image_data_format() == "channels_last" else -3
input_channels = inputs.shape[channels_axis]
hidden_channels = make_divisible(input_channels * expansion_ratio)
has_skip = strides == 1 and input_channels == output_channels

Expand Down
8 changes: 7 additions & 1 deletion kimm/blocks/transformer_block.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import typing

from keras import backend
from keras import layers

from kimm import layers as kimm_layers
Expand All @@ -13,9 +14,13 @@ def apply_mlp_block(
use_bias: bool = True,
dropout_rate: float = 0.0,
use_conv_mlp: bool = False,
data_format: typing.Optional[str] = None,
name: str = "mlp_block",
):
input_dim = inputs.shape[-1]
if data_format is None:
data_format = backend.image_data_format()
dim_axis = -1 if data_format == "channels_last" else 1
input_dim = inputs.shape[dim_axis]
output_dim = output_dim or input_dim

x = inputs
Expand Down Expand Up @@ -71,6 +76,7 @@ def apply_transformer_block(
int(dim * mlp_ratio),
activation=activation,
dropout_rate=projection_dropout_rate,
data_format="channels_last", # TODO: let backend decides
name=f"{name}_mlp",
)
x = layers.Add()([residual_2, x])
Expand Down
2 changes: 2 additions & 0 deletions kimm/export/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from kimm.export.export_onnx import export_onnx
from kimm.export.export_tflite import export_tflite
133 changes: 133 additions & 0 deletions kimm/export/export_onnx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
import pathlib
import tempfile
import typing

from keras import backend
from keras import layers
from keras import models
from keras import ops

from kimm.models import BaseModel


def _export_onnx_tf(
model: BaseModel,
inputs_as_nchw,
export_path: typing.Union[str, pathlib.Path],
):
try:
import tf2onnx
import tf2onnx.tf_loader
except ModuleNotFoundError:
raise ModuleNotFoundError(
"Failed to import 'tf2onnx'. Please install it by the following "
"instruction:\n'pip install tf2onnx'"
)

with tempfile.TemporaryDirectory() as temp_dir:
temp_path = pathlib.Path(temp_dir, "temp_saved_model")
model.export(temp_path)

(
graph_def,
inputs,
outputs,
tensors_to_rename,
) = tf2onnx.tf_loader.from_saved_model(
temp_path,
None,
None,
return_tensors_to_rename=True,
)

tf2onnx.convert.from_graph_def(
graph_def,
input_names=inputs,
output_names=outputs,
output_path=export_path,
inputs_as_nchw=inputs_as_nchw,
tensors_to_rename=tensors_to_rename,
)


def _export_onnx_torch(
model: BaseModel,
input_shape: typing.Union[int, int, int],
export_path: typing.Union[str, pathlib.Path],
):
try:
import torch
except ModuleNotFoundError:
raise ModuleNotFoundError(
"Failed to import 'torch'. Please install it before calling"
"`export_onnx` using torch backend"
)
full_input_shape = [1] + list(input_shape)
dummy_inputs = ops.ones(full_input_shape)
scripted_model = torch.jit.trace(model, dummy_inputs).eval()
torch.onnx.export(scripted_model, dummy_inputs, export_path)


def export_onnx(
model: BaseModel,
input_shape: typing.Union[int, typing.Sequence[int]],
export_path: typing.Union[str, pathlib.Path],
batch_size: int = 1,
use_nchw: bool = True,
):
if backend.backend() not in ("tensorflow", "torch"):
raise ValueError(
"Currently, `export_onnx` only supports tensorflow and torch "
"backend"
)
try:
import onnx
import onnxoptimizer
import onnxsim
except ModuleNotFoundError:
raise ModuleNotFoundError(
"Failed to import 'onnx', 'onnxsim' or 'onnxoptimizer'. Please "
"install them by the following instruction:\n"
"'pip install onnx onnxsim onnxoptimizer'"
)

if isinstance(input_shape, int):
input_shape = [input_shape, input_shape, 3]
elif len(input_shape) == 2:
input_shape = [input_shape[0], input_shape[1], 3]
elif len(input_shape) == 3:
input_shape = input_shape
if use_nchw:
if backend.backend() == "torch":
raise ValueError(
"Currently, torch backend doesn't support `use_nchw=True`. "
"You can use tensorflow backend to overcome this issue or "
"set `use_nchw=False`. "
"Note that there might be a significant performance "
"degradation when using torch backend to export onnx due to "
"the pre- and post-transpose of the Conv2D."
)
elif backend.backend() == "tensorflow":
inputs_as_nchw = ["inputs"]
else:
inputs_as_nchw = None
else:
inputs_as_nchw = None

# Fix input shape
inputs = layers.Input(
shape=input_shape, batch_size=batch_size, name="inputs"
)
outputs = model(inputs, training=False)
model = models.Model(inputs, outputs)

if backend.backend() == "tensorflow":
_export_onnx_tf(model, inputs_as_nchw, export_path)
elif backend.backend() == "torch":
_export_onnx_torch(model, input_shape, export_path)

# Further optimization
model = onnx.load(export_path)
model_simp, _ = onnxsim.simplify(model)
model_simp = onnxoptimizer.optimize(model_simp)
onnx.save(model_simp, export_path)
30 changes: 30 additions & 0 deletions kimm/export/export_onnx_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import pytest
from absl.testing import parameterized
from keras import backend
from keras.src import testing

from kimm import export
from kimm import models


class ExportOnnxTest(testing.TestCase, parameterized.TestCase):
def get_model(self):
input_shape = [224, 224, 3]
model = models.MobileNet050V3Small(include_preprocessing=False)
return input_shape, model

@pytest.mark.skipif(
backend.backend() != "tensorflow", # TODO: test torch
reason="Requires tensorflow or torch backend.",
)
def test_export_onnx_use(self):
input_shape, model = self.get_model()

temp_dir = self.get_temp_dir()

if backend.backend() == "tensorflow":
export.export_onnx(model, input_shape, f"{temp_dir}/model.onnx")
elif backend.backend() == "torch":
export.export_onnx(
model, input_shape, f"{temp_dir}/model.onnx", use_nchw=False
)
72 changes: 72 additions & 0 deletions kimm/export/export_tflite.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import pathlib
import tempfile
import typing

from keras import backend
from keras import layers
from keras import models
from keras.src.utils.module_utils import tensorflow as tf

from kimm.models import BaseModel


def export_tflite(
model: BaseModel,
input_shape: typing.Union[int, typing.Sequence[int]],
export_path: typing.Union[str, pathlib.Path],
export_dtype: typing.Literal["float32", "float16", "int8"] = "float32",
representative_dataset: typing.Optional[typing.Iterator] = None,
batch_size: int = 1,
):
if backend.backend() != "tensorflow":
raise ValueError(
"Currently, `export_tflite` only supports tensorflow backend"
)
if export_dtype not in ("float32", "float16", "int8"):
raise ValueError(
"`export_dtype` must be one of ('float32', 'float16', 'int8'). "
f"Received: export_dtype={export_dtype}"
)
if export_dtype == "int8" and representative_dataset is None:
raise ValueError(
"For full integer quantization, a `representative_dataset` should "
"be specified."
)
if isinstance(input_shape, int):
input_shape = [input_shape, input_shape, 3]
elif len(input_shape) == 2:
input_shape = [input_shape[0], input_shape[1], 3]
elif len(input_shape) == 3:
input_shape = input_shape

# Fix input shape
inputs = layers.Input(shape=input_shape, batch_size=batch_size)
outputs = model(inputs, training=False)
model = models.Model(inputs, outputs)

# Construct TFLiteConverter
with tempfile.TemporaryDirectory() as temp_dir:
temp_path = pathlib.Path(temp_dir, "temp_saved_model")
model.export(temp_path)
converter = tf.lite.TFLiteConverter.from_saved_model(str(temp_path))

# Configure converter
if export_dtype != "float32":
converter.optimizations = [tf.lite.Optimize.DEFAULT]
if export_dtype == "int8":
converter.target_spec.supported_ops = [
tf.lite.OpsSet.TFLITE_BUILTINS_INT8
]
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8
elif export_dtype == "float16":
converter.target_spec.supported_types = [tf.float16]
if representative_dataset is not None:
converter.representative_dataset = representative_dataset

# Convert
tflite_model = converter.convert()

# Export
with open(export_path, "wb") as f:
f.write(tflite_model)
Loading

0 comments on commit b8deb81

Please sign in to comment.