-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add `export_tflite` and `export_onnx` and support `channels_first`
- Loading branch information
1 parent
01007b3
commit b8deb81
Showing
28 changed files
with
586 additions
and
92 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -161,4 +161,8 @@ cython_debug/ | |
|
||
# Keras | ||
*.keras | ||
exported | ||
exported | ||
|
||
# Exported model | ||
*.tflite | ||
*.onnx |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from kimm.export.export_onnx import export_onnx | ||
from kimm.export.export_tflite import export_tflite |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
import pathlib | ||
import tempfile | ||
import typing | ||
|
||
from keras import backend | ||
from keras import layers | ||
from keras import models | ||
from keras import ops | ||
|
||
from kimm.models import BaseModel | ||
|
||
|
||
def _export_onnx_tf(
    model: BaseModel,
    inputs_as_nchw,
    export_path: typing.Union[str, pathlib.Path],
):
    """Convert `model` to ONNX via a temporary TensorFlow SavedModel.

    The model is written to a scratch directory with `model.export`, read
    back with `tf2onnx.tf_loader.from_saved_model` and handed to
    `tf2onnx.convert.from_graph_def`, which writes to `export_path`.

    Args:
        model: The model to export; only its `export` method is used here.
        inputs_as_nchw: Forwarded unchanged to tf2onnx's `inputs_as_nchw`
            argument.
        export_path: Forwarded to tf2onnx as `output_path`.

    Raises:
        ModuleNotFoundError: If `tf2onnx` cannot be imported.
    """
    try:
        import tf2onnx
        import tf2onnx.tf_loader
    except ModuleNotFoundError:
        raise ModuleNotFoundError(
            "Failed to import 'tf2onnx'. Please install it by the following "
            "instruction:\n'pip install tf2onnx'"
        )

    # Everything that touches the SavedModel stays inside the `with` so the
    # scratch directory still exists while tf2onnx works on it.
    with tempfile.TemporaryDirectory() as scratch_dir:
        saved_model_dir = pathlib.Path(scratch_dir, "temp_saved_model")
        model.export(saved_model_dir)

        loaded = tf2onnx.tf_loader.from_saved_model(
            saved_model_dir,
            None,
            None,
            return_tensors_to_rename=True,
        )
        graph_def, input_names, output_names, rename_map = loaded

        tf2onnx.convert.from_graph_def(
            graph_def,
            input_names=input_names,
            output_names=output_names,
            output_path=export_path,
            inputs_as_nchw=inputs_as_nchw,
            tensors_to_rename=rename_map,
        )
|
||
|
||
def _export_onnx_torch(
    model: BaseModel,
    input_shape: typing.Sequence[int],
    export_path: typing.Union[str, pathlib.Path],
):
    """Export `model` to ONNX by tracing it with torch.

    A dummy all-ones input is built with `ops.ones`, the model is traced
    with `torch.jit.trace` and the traced module is passed to
    `torch.onnx.export`.

    Args:
        model: The model to trace.
        input_shape: Per-sample input shape. A leading dimension of 1 is
            prepended to form the dummy input, so the trace always uses a
            batch of one.
        export_path: Destination passed to `torch.onnx.export`.

    Raises:
        ModuleNotFoundError: If `torch` cannot be imported.
    """
    try:
        import torch
    except ModuleNotFoundError as err:
        # The trailing space matters: the two literals are concatenated.
        raise ModuleNotFoundError(
            "Failed to import 'torch'. Please install it before calling "
            "`export_onnx` using torch backend"
        ) from err
    full_input_shape = [1, *input_shape]
    dummy_inputs = ops.ones(full_input_shape)
    traced_model = torch.jit.trace(model, dummy_inputs).eval()
    torch.onnx.export(traced_model, dummy_inputs, export_path)
|
||
|
||
def export_onnx(
    model: BaseModel,
    input_shape: typing.Union[int, typing.Sequence[int]],
    export_path: typing.Union[str, pathlib.Path],
    batch_size: int = 1,
    use_nchw: bool = True,
):
    """Export `model` to an ONNX file and post-process it.

    The model is rebuilt on a fixed-shape `layers.Input`, exported through
    the backend-specific helper, then reloaded from `export_path`, run
    through `onnxsim.simplify` and `onnxoptimizer.optimize`, and saved back
    to the same path.

    Args:
        model: The model to export.
        input_shape: An int `s` is expanded to `[s, s, 3]`; a sequence of
            two ints `(a, b)` to `[a, b, 3]`; a sequence of three ints is
            used as given.
        export_path: Path of the resulting ONNX file.
        batch_size: Batch size of the fixed `layers.Input`.
        use_nchw: When true (tensorflow backend only), the input named
            "inputs" is passed to tf2onnx as `inputs_as_nchw`.

    Raises:
        ValueError: If the backend is neither tensorflow nor torch, if
            `input_shape` is not an int or a sequence of 2 or 3 ints, or if
            `use_nchw=True` is requested with the torch backend.
        ModuleNotFoundError: If `onnx`, `onnxsim` or `onnxoptimizer` cannot
            be imported.
    """
    backend_name = backend.backend()
    if backend_name not in ("tensorflow", "torch"):
        raise ValueError(
            "Currently, `export_onnx` only supports tensorflow and torch "
            "backend"
        )
    try:
        import onnx
        import onnxoptimizer
        import onnxsim
    except ModuleNotFoundError as err:
        raise ModuleNotFoundError(
            "Failed to import 'onnx', 'onnxsim' or 'onnxoptimizer'. Please "
            "install them by the following instruction:\n"
            "'pip install onnx onnxsim onnxoptimizer'"
        ) from err

    # Normalize `input_shape` to three entries; reject anything else early
    # instead of letting it fail later inside the model.
    if isinstance(input_shape, int):
        input_shape = [input_shape, input_shape, 3]
    elif len(input_shape) == 2:
        input_shape = [input_shape[0], input_shape[1], 3]
    elif len(input_shape) != 3:
        raise ValueError(
            "`input_shape` must be an int or a sequence of 2 or 3 ints. "
            f"Received: input_shape={input_shape}"
        )

    if use_nchw and backend_name == "torch":
        raise ValueError(
            "Currently, torch backend doesn't support `use_nchw=True`. "
            "You can use tensorflow backend to overcome this issue or "
            "set `use_nchw=False`. "
            "Note that there might be a significant performance "
            "degradation when using torch backend to export onnx due to "
            "the pre- and post-transpose of the Conv2D."
        )
    # "inputs" must match the name given to `layers.Input` below.
    inputs_as_nchw = ["inputs"] if use_nchw else None

    # Fix input shape
    inputs = layers.Input(
        shape=input_shape, batch_size=batch_size, name="inputs"
    )
    outputs = model(inputs, training=False)
    fixed_model = models.Model(inputs, outputs)

    if backend_name == "tensorflow":
        _export_onnx_tf(fixed_model, inputs_as_nchw, export_path)
    else:  # "torch"; any other backend was rejected above
        _export_onnx_torch(fixed_model, input_shape, export_path)

    # Further optimization
    onnx_model = onnx.load(export_path)
    # The second return value of `simplify` is ignored.
    simplified_model, _ = onnxsim.simplify(onnx_model)
    optimized_model = onnxoptimizer.optimize(simplified_model)
    onnx.save(optimized_model, export_path)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
import pytest | ||
from absl.testing import parameterized | ||
from keras import backend | ||
from keras.src import testing | ||
|
||
from kimm import export | ||
from kimm import models | ||
|
||
|
||
class ExportOnnxTest(testing.TestCase, parameterized.TestCase):
    """Smoke test: `export.export_onnx` runs on a small MobileNet."""

    def get_model(self):
        """Return `(input_shape, model)` for a MobileNet050V3Small."""
        model = models.MobileNet050V3Small(include_preprocessing=False)
        return [224, 224, 3], model

    @pytest.mark.skipif(
        backend.backend() != "tensorflow",  # TODO: test torch
        reason="Requires tensorflow or torch backend.",
    )
    def test_export_onnx_use(self):
        input_shape, model = self.get_model()
        export_path = f"{self.get_temp_dir()}/model.onnx"
        current_backend = backend.backend()

        # The torch branch is kept for when the skip condition is relaxed.
        if current_backend == "tensorflow":
            export.export_onnx(model, input_shape, export_path)
        elif current_backend == "torch":
            export.export_onnx(
                model, input_shape, export_path, use_nchw=False
            )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
import pathlib | ||
import tempfile | ||
import typing | ||
|
||
from keras import backend | ||
from keras import layers | ||
from keras import models | ||
from keras.src.utils.module_utils import tensorflow as tf | ||
|
||
from kimm.models import BaseModel | ||
|
||
|
||
def export_tflite(
    model: BaseModel,
    input_shape: typing.Union[int, typing.Sequence[int]],
    export_path: typing.Union[str, pathlib.Path],
    export_dtype: typing.Literal["float32", "float16", "int8"] = "float32",
    representative_dataset: typing.Optional[typing.Iterator] = None,
    batch_size: int = 1,
):
    """Export `model` to a TFLite flatbuffer written to `export_path`.

    The model is rebuilt on a fixed-shape `layers.Input`, exported as a
    temporary SavedModel, converted with `tf.lite.TFLiteConverter` and the
    resulting bytes are written to `export_path`.

    Args:
        model: The model to export.
        input_shape: An int `s` is expanded to `[s, s, 3]`; a sequence of
            two ints `(a, b)` to `[a, b, 3]`; a sequence of three ints is
            used as given.
        export_path: Path of the resulting TFLite file.
        export_dtype: One of "float32", "float16" or "int8". Anything other
            than "float32" enables `tf.lite.Optimize.DEFAULT`.
        representative_dataset: Assigned to
            `converter.representative_dataset` when not `None`; required
            for "int8".
            NOTE(review): the TFLite converter is commonly given a callable
            that yields samples here -- confirm the annotation.
        batch_size: Batch size of the fixed `layers.Input`.

    Raises:
        ValueError: If the backend is not tensorflow, if `export_dtype` is
            not supported, if "int8" is requested without a
            `representative_dataset`, or if `input_shape` is not an int or
            a sequence of 2 or 3 ints.
    """
    if backend.backend() != "tensorflow":
        raise ValueError(
            "Currently, `export_tflite` only supports tensorflow backend"
        )
    if export_dtype not in ("float32", "float16", "int8"):
        raise ValueError(
            "`export_dtype` must be one of ('float32', 'float16', 'int8'). "
            f"Received: export_dtype={export_dtype}"
        )
    if export_dtype == "int8" and representative_dataset is None:
        raise ValueError(
            "For full integer quantization, a `representative_dataset` should "
            "be specified."
        )

    # Normalize `input_shape` to three entries; reject anything else early
    # instead of letting it fail later inside the model.
    if isinstance(input_shape, int):
        input_shape = [input_shape, input_shape, 3]
    elif len(input_shape) == 2:
        input_shape = [input_shape[0], input_shape[1], 3]
    elif len(input_shape) != 3:
        raise ValueError(
            "`input_shape` must be an int or a sequence of 2 or 3 ints. "
            f"Received: input_shape={input_shape}"
        )

    # Fix input shape
    inputs = layers.Input(shape=input_shape, batch_size=batch_size)
    outputs = model(inputs, training=False)
    fixed_model = models.Model(inputs, outputs)

    # Construct TFLiteConverter. Conversion is kept inside the `with` so the
    # temporary SavedModel directory still exists when `convert()` runs.
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_path = pathlib.Path(temp_dir, "temp_saved_model")
        fixed_model.export(temp_path)
        converter = tf.lite.TFLiteConverter.from_saved_model(str(temp_path))

        # Configure converter
        if export_dtype != "float32":
            converter.optimizations = [tf.lite.Optimize.DEFAULT]
        if export_dtype == "int8":
            converter.target_spec.supported_ops = [
                tf.lite.OpsSet.TFLITE_BUILTINS_INT8
            ]
            converter.inference_input_type = tf.int8
            converter.inference_output_type = tf.int8
        elif export_dtype == "float16":
            converter.target_spec.supported_types = [tf.float16]
        if representative_dataset is not None:
            converter.representative_dataset = representative_dataset

        # Convert
        tflite_model = converter.convert()

    # Export
    with open(export_path, "wb") as f:
        f.write(tflite_model)
Oops, something went wrong.