Skip to content

Commit

Permalink
[ONNX] Remove ExportTypes (pytorch#137789)
Browse files Browse the repository at this point in the history
Remove deprecated ExportTypes and the `_exporter_states` module. Only protobuf (default) is supported going forward.

Differential Revision: [D64412947](https://our.internmc.facebook.com/intern/diff/D64412947)
Pull Request resolved: pytorch#137789
Approved by: https://github.com/titaiwangms, https://github.com/xadupre
  • Loading branch information
justinchuby authored and pytorchmergebot committed Oct 21, 2024
1 parent af0bc75 commit 6e38c87
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 65 deletions.
3 changes: 0 additions & 3 deletions torch/onnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
"symbolic_opset19",
"symbolic_opset20",
# Enums
"ExportTypes",
"OperatorExportTypes",
"TrainingMode",
"TensorProtoDataType",
Expand Down Expand Up @@ -57,7 +56,6 @@
from torch._C import _onnx as _C_onnx
from torch._C._onnx import OperatorExportTypes, TensorProtoDataType, TrainingMode

from ._exporter_states import ExportTypes
from ._internal.exporter._onnx_program import ONNXProgram
from ._internal.onnxruntime import (
is_onnxrt_backend_supported,
Expand Down Expand Up @@ -115,7 +113,6 @@
# Set namespace for exposed private names
DiagnosticOptions.__module__ = "torch.onnx"
ExportOptions.__module__ = "torch.onnx"
ExportTypes.__module__ = "torch.onnx"
JitScalarType.__module__ = "torch.onnx"
ONNXProgram.__module__ = "torch.onnx"
ONNXRuntimeOptions.__module__ = "torch.onnx"
Expand Down
52 changes: 9 additions & 43 deletions torch/onnx/_internal/onnx_proto_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,21 @@
from __future__ import annotations

import glob
import io
import os
import shutil
import zipfile
from typing import Any, Mapping
from typing import Any, Mapping, TYPE_CHECKING

import torch
import torch.jit._trace
import torch.serialization
from torch.onnx import _constants, _exporter_states, errors
from torch.onnx import errors
from torch.onnx._internal import jit_utils, registration


if TYPE_CHECKING:
import io


def export_as_test_case(
model_bytes: bytes, inputs_data, outputs_data, name: str, dir: str
) -> str:
Expand Down Expand Up @@ -54,7 +56,6 @@ def export_as_test_case(
_export_file(
model_bytes,
os.path.join(test_case_dir, "model.onnx"),
_exporter_states.ExportTypes.PROTOBUF_FILE,
{},
)
data_set_dir = os.path.join(test_case_dir, "test_data_set_0")
Expand Down Expand Up @@ -163,47 +164,12 @@ def export_data(data, value_info_proto, f: str) -> None:
def _export_file(
model_bytes: bytes,
f: io.BytesIO | str,
export_type: str,
export_map: Mapping[str, bytes],
) -> None:
"""export/write model bytes into directory/protobuf/zip"""
if export_type == _exporter_states.ExportTypes.PROTOBUF_FILE:
assert len(export_map) == 0
with torch.serialization._open_file_like(f, "wb") as opened_file:
opened_file.write(model_bytes)
elif export_type in {
_exporter_states.ExportTypes.ZIP_ARCHIVE,
_exporter_states.ExportTypes.COMPRESSED_ZIP_ARCHIVE,
}:
compression = (
zipfile.ZIP_DEFLATED
if export_type == _exporter_states.ExportTypes.COMPRESSED_ZIP_ARCHIVE
else zipfile.ZIP_STORED
)
with zipfile.ZipFile(f, "w", compression=compression) as z:
z.writestr(_constants.ONNX_ARCHIVE_MODEL_PROTO_NAME, model_bytes)
for k, v in export_map.items():
z.writestr(k, v)
elif export_type == _exporter_states.ExportTypes.DIRECTORY:
if isinstance(f, io.BytesIO) or not os.path.isdir(f): # type: ignore[arg-type]
raise ValueError(
f"f should be directory when export_type is set to DIRECTORY, instead get type(f): {type(f)}"
)
if not os.path.exists(f): # type: ignore[arg-type]
os.makedirs(f) # type: ignore[arg-type]

model_proto_file = os.path.join(f, _constants.ONNX_ARCHIVE_MODEL_PROTO_NAME) # type: ignore[arg-type]
with torch.serialization._open_file_like(model_proto_file, "wb") as opened_file:
opened_file.write(model_bytes)

for k, v in export_map.items():
weight_proto_file = os.path.join(f, k) # type: ignore[arg-type]
with torch.serialization._open_file_like(
weight_proto_file, "wb"
) as opened_file:
opened_file.write(v)
else:
raise ValueError("Unknown export type")
assert len(export_map) == 0
with torch.serialization._open_file_like(f, "wb") as opened_file:
opened_file.write(model_bytes)


def _add_onnxscript_fn(
Expand Down
20 changes: 4 additions & 16 deletions torch/onnx/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,7 @@
import torch.jit._trace
import torch.serialization
from torch import _C
from torch.onnx import ( # noqa: F401
_constants,
_deprecation,
_exporter_states,
errors,
symbolic_helper,
)
from torch.onnx import _constants, _deprecation, errors, symbolic_helper # noqa: F401
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import diagnostics, jit_utils, onnx_proto_utils, registration

Expand Down Expand Up @@ -1423,9 +1417,6 @@ def _export(
):
assert GLOBALS.in_onnx_export is False

if export_type is None:
export_type = _exporter_states.ExportTypes.PROTOBUF_FILE

if isinstance(model, torch.nn.DataParallel):
raise ValueError(
"torch.nn.DataParallel is not supported by ONNX "
Expand Down Expand Up @@ -1516,10 +1507,6 @@ def _export(
dynamic_axes=dynamic_axes,
)

# TODO: Don't allocate a in-memory string for the protobuf
defer_weight_export = (
export_type is not _exporter_states.ExportTypes.PROTOBUF_FILE
)
if custom_opsets is None:
custom_opsets = {}

Expand All @@ -1540,6 +1527,7 @@ def _export(
getattr(model, "training", False), # type: ignore[arg-type]
)
_C._jit_pass_onnx_assign_scoped_names_for_node_and_value(graph)
defer_weight_export = False
if export_params:
(
proto,
Expand Down Expand Up @@ -1569,7 +1557,7 @@ def _export(
{},
opset_version,
dynamic_axes,
False,
defer_weight_export,
operator_export_type,
not verbose,
val_keep_init_as_ip,
Expand All @@ -1585,7 +1573,7 @@ def _export(
)
if verbose:
_C._jit_onnx_log("Exported graph: ", graph)
onnx_proto_utils._export_file(proto, f, export_type, export_map)
onnx_proto_utils._export_file(proto, f, export_map)
finally:
assert GLOBALS.in_onnx_export
GLOBALS.in_onnx_export = False
Expand Down
5 changes: 2 additions & 3 deletions torch/onnx/verification.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import torch
import torch._C._onnx as _C_onnx
from torch import _C
from torch.onnx import _constants, _experimental, _exporter_states, utils
from torch.onnx import _constants, _experimental, utils
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import onnx_proto_utils
from torch.types import Number
Expand Down Expand Up @@ -893,8 +893,7 @@ def verify_aten_graph(
graph, export_options, onnx_params_dict
)
model_f: str | io.BytesIO = io.BytesIO()
export_type = _exporter_states.ExportTypes.PROTOBUF_FILE
onnx_proto_utils._export_file(proto, model_f, export_type, export_map)
onnx_proto_utils._export_file(proto, model_f, export_map)

# NOTE: Verification is unstable. Try catch to emit information for debugging.
try:
Expand Down

0 comments on commit 6e38c87

Please sign in to comment.