diff --git a/torch/onnx/__init__.py b/torch/onnx/__init__.py index 597f1653c2bd69..b25c823049dc5e 100644 --- a/torch/onnx/__init__.py +++ b/torch/onnx/__init__.py @@ -24,7 +24,6 @@ "symbolic_opset19", "symbolic_opset20", # Enums - "ExportTypes", "OperatorExportTypes", "TrainingMode", "TensorProtoDataType", @@ -57,7 +56,6 @@ from torch._C import _onnx as _C_onnx from torch._C._onnx import OperatorExportTypes, TensorProtoDataType, TrainingMode -from ._exporter_states import ExportTypes from ._internal.exporter._onnx_program import ONNXProgram from ._internal.onnxruntime import ( is_onnxrt_backend_supported, @@ -115,7 +113,6 @@ # Set namespace for exposed private names DiagnosticOptions.__module__ = "torch.onnx" ExportOptions.__module__ = "torch.onnx" -ExportTypes.__module__ = "torch.onnx" JitScalarType.__module__ = "torch.onnx" ONNXProgram.__module__ = "torch.onnx" ONNXRuntimeOptions.__module__ = "torch.onnx" diff --git a/torch/onnx/_internal/onnx_proto_utils.py b/torch/onnx/_internal/onnx_proto_utils.py index 5fc181b180824a..19c31ab16f3803 100644 --- a/torch/onnx/_internal/onnx_proto_utils.py +++ b/torch/onnx/_internal/onnx_proto_utils.py @@ -4,19 +4,21 @@ from __future__ import annotations import glob -import io import os import shutil -import zipfile -from typing import Any, Mapping +from typing import Any, Mapping, TYPE_CHECKING import torch import torch.jit._trace import torch.serialization -from torch.onnx import _constants, _exporter_states, errors +from torch.onnx import errors from torch.onnx._internal import jit_utils, registration +if TYPE_CHECKING: + import io + + def export_as_test_case( model_bytes: bytes, inputs_data, outputs_data, name: str, dir: str ) -> str: @@ -54,7 +56,6 @@ def export_as_test_case( _export_file( model_bytes, os.path.join(test_case_dir, "model.onnx"), - _exporter_states.ExportTypes.PROTOBUF_FILE, {}, ) data_set_dir = os.path.join(test_case_dir, "test_data_set_0") @@ -163,47 +164,12 @@ def export_data(data, value_info_proto, f: str) -> None: def _export_file( model_bytes: bytes, f: io.BytesIO | str, - export_type: str, export_map: Mapping[str, bytes], ) -> None: """export/write model bytes into directory/protobuf/zip""" - if export_type == _exporter_states.ExportTypes.PROTOBUF_FILE: - assert len(export_map) == 0 - with torch.serialization._open_file_like(f, "wb") as opened_file: - opened_file.write(model_bytes) - elif export_type in { - _exporter_states.ExportTypes.ZIP_ARCHIVE, - _exporter_states.ExportTypes.COMPRESSED_ZIP_ARCHIVE, - }: - compression = ( - zipfile.ZIP_DEFLATED - if export_type == _exporter_states.ExportTypes.COMPRESSED_ZIP_ARCHIVE - else zipfile.ZIP_STORED - ) - with zipfile.ZipFile(f, "w", compression=compression) as z: - z.writestr(_constants.ONNX_ARCHIVE_MODEL_PROTO_NAME, model_bytes) - for k, v in export_map.items(): - z.writestr(k, v) - elif export_type == _exporter_states.ExportTypes.DIRECTORY: - if isinstance(f, io.BytesIO) or not os.path.isdir(f): # type: ignore[arg-type] - raise ValueError( - f"f should be directory when export_type is set to DIRECTORY, instead get type(f): {type(f)}" - ) - if not os.path.exists(f): # type: ignore[arg-type] - os.makedirs(f) # type: ignore[arg-type] - - model_proto_file = os.path.join(f, _constants.ONNX_ARCHIVE_MODEL_PROTO_NAME) # type: ignore[arg-type] - with torch.serialization._open_file_like(model_proto_file, "wb") as opened_file: - opened_file.write(model_bytes) - - for k, v in export_map.items(): - weight_proto_file = os.path.join(f, k) # type: ignore[arg-type] - with torch.serialization._open_file_like( - weight_proto_file, "wb" - ) as opened_file: - opened_file.write(v) - else: - raise ValueError("Unknown export type") + assert len(export_map) == 0 + with torch.serialization._open_file_like(f, "wb") as opened_file: + opened_file.write(model_bytes) def _add_onnxscript_fn( diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py index 033d957f1d9295..3d94581c79b045 100644 --- a/torch/onnx/utils.py +++ b/torch/onnx/utils.py @@ -20,13 +20,7 @@ import torch.jit._trace import torch.serialization from torch import _C -from torch.onnx import ( # noqa: F401 - _constants, - _deprecation, - _exporter_states, - errors, - symbolic_helper, -) +from torch.onnx import _constants, _deprecation, errors, symbolic_helper # noqa: F401 from torch.onnx._globals import GLOBALS from torch.onnx._internal import diagnostics, jit_utils, onnx_proto_utils, registration @@ -1423,9 +1417,6 @@ def _export( ): assert GLOBALS.in_onnx_export is False - if export_type is None: - export_type = _exporter_states.ExportTypes.PROTOBUF_FILE - if isinstance(model, torch.nn.DataParallel): raise ValueError( "torch.nn.DataParallel is not supported by ONNX " @@ -1516,10 +1507,6 @@ def _export( dynamic_axes=dynamic_axes, ) - # TODO: Don't allocate a in-memory string for the protobuf - defer_weight_export = ( - export_type is not _exporter_states.ExportTypes.PROTOBUF_FILE - ) if custom_opsets is None: custom_opsets = {} @@ -1540,6 +1527,7 @@ def _export( getattr(model, "training", False), # type: ignore[arg-type] ) _C._jit_pass_onnx_assign_scoped_names_for_node_and_value(graph) + defer_weight_export = False if export_params: ( proto, @@ -1569,7 +1557,7 @@ def _export( {}, opset_version, dynamic_axes, - False, + defer_weight_export, operator_export_type, not verbose, val_keep_init_as_ip, @@ -1585,7 +1573,7 @@ def _export( ) if verbose: _C._jit_onnx_log("Exported graph: ", graph) - onnx_proto_utils._export_file(proto, f, export_type, export_map) + onnx_proto_utils._export_file(proto, f, export_map) finally: assert GLOBALS.in_onnx_export GLOBALS.in_onnx_export = False diff --git a/torch/onnx/verification.py b/torch/onnx/verification.py index ddf66923364ae1..26810b116ffc01 100644 --- a/torch/onnx/verification.py +++ b/torch/onnx/verification.py @@ -26,7 +26,7 @@ import torch import torch._C._onnx as _C_onnx from torch import _C -from torch.onnx import _constants, _experimental, _exporter_states, utils +from torch.onnx import _constants, _experimental, utils from torch.onnx._globals import GLOBALS from torch.onnx._internal import onnx_proto_utils from torch.types import Number @@ -893,8 +893,7 @@ def verify_aten_graph( graph, export_options, onnx_params_dict ) model_f: str | io.BytesIO = io.BytesIO() - export_type = _exporter_states.ExportTypes.PROTOBUF_FILE - onnx_proto_utils._export_file(proto, model_f, export_type, export_map) + onnx_proto_utils._export_file(proto, model_f, export_map) # NOTE: Verification is unstable. Try catch to emit information for debugging. try: