[ONNX] Remove deprecated export_to_pretty_string (pytorch#137790)
Pull Request resolved: pytorch#137790
Approved by: https://github.com/titaiwangms, https://github.com/xadupre
ghstack dependencies: pytorch#137789
justinchuby authored and pytorchmergebot committed Oct 21, 2024
1 parent 07cc4bd commit c6609ec
Showing 4 changed files with 48 additions and 119 deletions.
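
For reference, the deprecation notice on the removed function points callers to `onnx.printer.to_text()` as the replacement. A minimal migration sketch, assuming the `onnx` package is installed; the toy `AddModel` and the in-memory buffer are illustrative, not part of this PR:

```python
import io

import onnx
import torch


class AddModel(torch.nn.Module):
    # Toy module for illustration only; not part of this PR.
    def forward(self, x):
        return x + x


# Before this PR (now removed):
#     text = torch.onnx.export_to_pretty_string(AddModel(), (torch.zeros(1, 2, 3),))
# After: serialize with torch.onnx.export, then pretty-print with the onnx package.
f = io.BytesIO()
torch.onnx.export(AddModel(), (torch.zeros(1, 2, 3),), f)
text = onnx.printer.to_text(onnx.load_from_string(f.getvalue()))
print(text)
```

Exporting to an `io.BytesIO()` buffer instead of a file mirrors how the updated tests below exercise `torch.onnx.export` without touching disk.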
1 change: 0 additions & 1 deletion docs/source/onnx_torchscript.rst
@@ -697,7 +697,6 @@ Functions
 ^^^^^^^^^
 
 .. autofunction:: export
-.. autofunction:: export_to_pretty_string
 .. autofunction:: register_custom_op_symbolic
 .. autofunction:: unregister_custom_op_symbolic
 .. autofunction:: select_model_mode_for_export
85 changes: 48 additions & 37 deletions test/onnx/test_pytorch_onnx_no_runtime.py
@@ -83,7 +83,7 @@ def forward(self, x):
 
         x = torch.ones(3, 3)
         f = io.BytesIO()
-        torch.onnx.export(AddmmModel(), x, f, verbose=False)
+        torch.onnx.export(AddmmModel(), x, f)
 
     def test_onnx_transpose_incomplete_tensor_type(self):
         # Smoke test to get us into the state where we are attempting to export
@@ -115,7 +115,8 @@ def foo(x):
 
         traced = torch.jit.trace(foo, (torch.rand([2])))
 
-        torch.onnx.export_to_pretty_string(traced, (torch.rand([2]),))
+        f = io.BytesIO()
+        torch.onnx.export(traced, (torch.rand([2]),), f)
 
     def test_onnx_export_script_module(self):
         class ModuleToExport(torch.jit.ScriptModule):
@@ -125,7 +126,8 @@ def forward(self, x):
                 return x + x
 
         mte = ModuleToExport()
-        torch.onnx.export_to_pretty_string(mte, (torch.zeros(1, 2, 3),), verbose=False)
+        f = io.BytesIO()
+        torch.onnx.export(mte, (torch.zeros(1, 2, 3),), f)
 
     @common_utils.suppress_warnings
     def test_onnx_export_func_with_warnings(self):
@@ -138,9 +140,8 @@ def forward(self, x):
                 return func_with_warning(x)
 
         # no exception
-        torch.onnx.export_to_pretty_string(
-            WarningTest(), torch.randn(42), verbose=False
-        )
+        f = io.BytesIO()
+        torch.onnx.export(WarningTest(), torch.randn(42), f)
 
     def test_onnx_export_script_python_fail(self):
         class PythonModule(torch.jit.ScriptModule):
@@ -161,7 +162,7 @@ def forward(self, x):
         mte = ModuleToExport()
         f = io.BytesIO()
         with self.assertRaisesRegex(RuntimeError, "Couldn't export Python"):
-            torch.onnx.export(mte, (torch.zeros(1, 2, 3),), f, verbose=False)
+            torch.onnx.export(mte, (torch.zeros(1, 2, 3),), f)
 
     def test_onnx_export_script_inline_trace(self):
         class ModuleToInline(torch.nn.Module):
@@ -179,7 +180,8 @@ def forward(self, x):
                 return y + y
 
         mte = ModuleToExport()
-        torch.onnx.export_to_pretty_string(mte, (torch.zeros(1, 2, 3),), verbose=False)
+        f = io.BytesIO()
+        torch.onnx.export(mte, (torch.zeros(1, 2, 3),), f)
 
     def test_onnx_export_script_inline_script(self):
         class ModuleToInline(torch.jit.ScriptModule):
@@ -198,7 +200,8 @@ def forward(self, x):
                 return y + y
 
         mte = ModuleToExport()
-        torch.onnx.export_to_pretty_string(mte, (torch.zeros(1, 2, 3),), verbose=False)
+        f = io.BytesIO()
+        torch.onnx.export(mte, (torch.zeros(1, 2, 3),), f)
 
     def test_onnx_export_script_module_loop(self):
         class ModuleToExport(torch.jit.ScriptModule):
@@ -212,7 +215,8 @@ def forward(self, x):
                 return x
 
         mte = ModuleToExport()
-        torch.onnx.export_to_pretty_string(mte, (torch.zeros(1, 2, 3),), verbose=False)
+        f = io.BytesIO()
+        torch.onnx.export(mte, (torch.zeros(1, 2, 3),), f)
 
     @common_utils.suppress_warnings
     def test_onnx_export_script_truediv(self):
@@ -224,9 +228,8 @@ def forward(self, x):
 
         mte = ModuleToExport()
 
-        torch.onnx.export_to_pretty_string(
-            mte, (torch.zeros(1, 2, 3, dtype=torch.float),), verbose=False
-        )
+        f = io.BytesIO()
+        torch.onnx.export(mte, (torch.zeros(1, 2, 3, dtype=torch.float),), f)
 
     def test_onnx_export_script_non_alpha_add_sub(self):
         class ModuleToExport(torch.jit.ScriptModule):
@@ -236,7 +239,8 @@ def forward(self, x):
                 return bs - 1
 
         mte = ModuleToExport()
-        torch.onnx.export_to_pretty_string(mte, (torch.rand(3, 4),), verbose=False)
+        f = io.BytesIO()
+        torch.onnx.export(mte, (torch.rand(3, 4),), f)
 
     def test_onnx_export_script_module_if(self):
         class ModuleToExport(torch.jit.ScriptModule):
@@ -247,7 +251,8 @@ def forward(self, x):
                 return x
 
         mte = ModuleToExport()
-        torch.onnx.export_to_pretty_string(mte, (torch.zeros(1, 2, 3),), verbose=False)
+        f = io.BytesIO()
+        torch.onnx.export(mte, (torch.zeros(1, 2, 3),), f)
 
     def test_onnx_export_script_inline_params(self):
         class ModuleToInline(torch.jit.ScriptModule):
@@ -277,7 +282,8 @@ def forward(self, x):
             torch.mm(torch.zeros(2, 3), torch.ones(3, 3)), torch.ones(3, 4)
         )
         self.assertEqual(result, reference)
-        torch.onnx.export_to_pretty_string(mte, (torch.ones(2, 3),), verbose=False)
+        f = io.BytesIO()
+        torch.onnx.export(mte, (torch.ones(2, 3),), f)
 
     def test_onnx_export_speculate(self):
         class Foo(torch.jit.ScriptModule):
@@ -312,8 +318,10 @@ def transpose(x):
         f1 = Foo(transpose)
         f2 = Foo(linear)
 
-        torch.onnx.export_to_pretty_string(f1, (torch.ones(1, 10, dtype=torch.float),))
-        torch.onnx.export_to_pretty_string(f2, (torch.ones(1, 10, dtype=torch.float),))
+        f = io.BytesIO()
+        torch.onnx.export(f1, (torch.ones(1, 10, dtype=torch.float),), f)
+        f = io.BytesIO()
+        torch.onnx.export(f2, (torch.ones(1, 10, dtype=torch.float),), f)
 
     def test_onnx_export_shape_reshape(self):
         class Foo(torch.nn.Module):
@@ -326,17 +334,20 @@ def forward(self, x):
                 return reshaped
 
         foo = torch.jit.trace(Foo(), torch.zeros(1, 2, 3))
-        torch.onnx.export_to_pretty_string(foo, (torch.zeros(1, 2, 3)))
+        f = io.BytesIO()
+        torch.onnx.export(foo, (torch.zeros(1, 2, 3)), f)
 
     def test_listconstruct_erasure(self):
         class FooMod(torch.nn.Module):
             def forward(self, x):
                 mask = x < 0.0
                 return x[mask]
 
-        torch.onnx.export_to_pretty_string(
+        f = io.BytesIO()
+        torch.onnx.export(
             FooMod(),
             (torch.rand(3, 4),),
+            f,
             add_node_names=False,
             do_constant_folding=False,
             operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
@@ -351,13 +362,10 @@ def forward(self, x):
                 retval += torch.sum(x[0:i], dim=0)
             return retval
 
-        mod = DynamicSliceExportMod()
-
         input = torch.rand(3, 4, 5)
 
-        torch.onnx.export_to_pretty_string(
-            DynamicSliceExportMod(), (input,), opset_version=10
-        )
+        f = io.BytesIO()
+        torch.onnx.export(DynamicSliceExportMod(), (input,), f, opset_version=10)
 
     def test_export_dict(self):
         class DictModule(torch.nn.Module):
@@ -368,10 +376,12 @@ def forward(self, x_in: torch.Tensor) -> Dict[str, torch.Tensor]:
         mod = DictModule()
         mod.train(False)
 
-        torch.onnx.export_to_pretty_string(mod, (x_in,))
+        f = io.BytesIO()
+        torch.onnx.export(mod, (x_in,), f)
 
         with self.assertRaisesRegex(RuntimeError, r"DictConstruct.+is not supported."):
-            torch.onnx.export_to_pretty_string(torch.jit.script(mod), (x_in,))
+            f = io.BytesIO()
+            torch.onnx.export(torch.jit.script(mod), (x_in,), f)
 
     def test_source_range_propagation(self):
         class ExpandingModule(torch.nn.Module):
@@ -497,11 +507,11 @@ def forward(self, box_regression: Tensor, proposals: List[Tensor]):
         proposal = [torch.randn(2, 4), torch.randn(2, 4)]
 
         with self.assertRaises(RuntimeError) as cm:
-            onnx_model = io.BytesIO()
+            f = io.BytesIO()
             torch.onnx.export(
                 model,
                 (box_regression, proposal),
-                onnx_model,
+                f,
             )
 
     def test_initializer_sequence(self):
@@ -637,7 +647,7 @@ def forward(self, x):
 
         x = torch.randn(1, 2, 3, requires_grad=True)
         f = io.BytesIO()
-        torch.onnx.export(Model(), x, f)
+        torch.onnx.export(Model(), (x,), f)
         model = onnx.load(f)
         model.ir_version = 0
 
@@ -744,7 +754,7 @@ def forward(self, x):
 
         f = io.BytesIO()
         with warnings.catch_warnings(record=True):
-            torch.onnx.export(MyDrop(), (eg,), f, verbose=False)
+            torch.onnx.export(MyDrop(), (eg,), f)
 
     def test_pack_padded_pad_packed_trace(self):
         from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
@@ -791,7 +801,7 @@ def forward(self, x, seq_lens):
         self.assertEqual(grad, grad_traced)
 
         f = io.BytesIO()
-        torch.onnx.export(m, (x, seq_lens), f, verbose=False)
+        torch.onnx.export(m, (x, seq_lens), f)
 
     # Suppression: ONNX warns when exporting RNNs because of potential batch size mismatch.
     @common_utils.suppress_warnings
@@ -851,7 +861,7 @@ def forward(self, x, seq_lens):
         self.assertEqual(grad, grad_traced)
 
         f = io.BytesIO()
-        torch.onnx.export(m, (x, seq_lens), f, verbose=False)
+        torch.onnx.export(m, (x, seq_lens), f)
 
     def test_pushpackingpastrnn_in_peephole_create_own_gather_input(self):
         from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
@@ -931,7 +941,8 @@ class Mod(torch.nn.Module):
             def forward(self, x, w):
                 return torch.matmul(x, w).detach()
 
-        torch.onnx.export_to_pretty_string(Mod(), (torch.rand(3, 4), torch.rand(4, 5)))
+        f = io.BytesIO()
+        torch.onnx.export(Mod(), (torch.rand(3, 4), torch.rand(4, 5)), f)
 
     def test_aten_fallback_must_fallback(self):
         class ModelWithAtenNotONNXOp(torch.nn.Module):
@@ -1088,12 +1099,12 @@ def sym_scatter_max(g, src, index, dim, out, dim_size):
         torch.onnx.register_custom_op_symbolic(
             "torch_scatter::scatter_max", sym_scatter_max, 1
         )
+        f = io.BytesIO()
         with torch.no_grad():
             torch.onnx.export(
                 m,
                 (src, idx),
-                "mymodel.onnx",
-                verbose=False,
+                f,
                 opset_version=13,
                 custom_opsets={"torch_scatter": 1},
                 do_constant_folding=True,
@@ -1176,7 +1187,7 @@ def forward(self, x):
         model = Net(C).cuda().half()
         x = torch.randn(N, C).cuda().half()
         f = io.BytesIO()
-        torch.onnx.export(model, x, f, opset_version=14)
+        torch.onnx.export(model, (x,), f, opset_version=14)
         onnx_model = onnx.load_from_string(f.getvalue())
         const_node = [n for n in onnx_model.graph.node if n.op_type == "Constant"]
         self.assertNotEqual(len(const_node), 0)
2 changes: 0 additions & 2 deletions torch/onnx/__init__.py
@@ -30,7 +30,6 @@
     "JitScalarType",
     # Public functions
     "export",
-    "export_to_pretty_string",
     "is_in_onnx_export",
     "select_model_mode_for_export",
     "register_custom_op_symbolic",
@@ -68,7 +67,6 @@
 from .utils import (
     _run_symbolic_function,
     _run_symbolic_method,
-    export_to_pretty_string,
     is_in_onnx_export,
     register_custom_op_symbolic,
     select_model_mode_for_export,
79 changes: 0 additions & 79 deletions torch/onnx/utils.py
@@ -35,7 +35,6 @@
     "model_signature",
     "warn_on_static_input_change",
     "unpack_quantized_tensor",
-    "export_to_pretty_string",
    "unconvertible_ops",
     "register_custom_op_symbolic",
     "unregister_custom_op_symbolic",
@@ -1140,84 +1139,6 @@ def _model_to_graph(
     return graph, params_dict, torch_out
 
 
-@torch._disable_dynamo
-@_deprecation.deprecated("2.5", "the future", "use onnx.printer.to_text() instead")
-def export_to_pretty_string(
-    model,
-    args,
-    export_params=True,
-    verbose=False,
-    training=_C_onnx.TrainingMode.EVAL,
-    input_names=None,
-    output_names=None,
-    operator_export_type=_C_onnx.OperatorExportTypes.ONNX,
-    export_type=None,
-    google_printer=False,
-    opset_version=None,
-    keep_initializers_as_inputs=None,
-    custom_opsets=None,
-    add_node_names=True,
-    do_constant_folding=True,
-    dynamic_axes=None,
-):
-    """Similar to :func:`export`, but returns a text representation of the ONNX model.
-
-    Only differences in args listed below. All other args are the same
-    as :func:`export`.
-
-    Args:
-        add_node_names (bool, default True): Whether or not to set
-            NodeProto.name. This makes no difference unless
-            ``google_printer=True``.
-        google_printer (bool, default False): If False, will return a custom,
-            compact representation of the model. If True will return the
-            protobuf's `Message::DebugString()`, which is more verbose.
-
-    Returns:
-        A UTF-8 str containing a human-readable representation of the ONNX model.
-    """
-    if opset_version is None:
-        opset_version = _constants.ONNX_DEFAULT_OPSET
-    if custom_opsets is None:
-        custom_opsets = {}
-    GLOBALS.export_onnx_opset_version = opset_version
-    GLOBALS.operator_export_type = operator_export_type
-
-    with exporter_context(model, training, verbose):
-        val_keep_init_as_ip = _decide_keep_init_as_input(
-            keep_initializers_as_inputs, operator_export_type, opset_version
-        )
-        val_add_node_names = _decide_add_node_names(
-            add_node_names, operator_export_type
-        )
-        val_do_constant_folding = _decide_constant_folding(
-            do_constant_folding, operator_export_type, training
-        )
-        args = _decide_input_format(model, args)
-        graph, params_dict, torch_out = _model_to_graph(
-            model,
-            args,
-            verbose,
-            input_names,
-            output_names,
-            operator_export_type,
-            val_do_constant_folding,
-            training=training,
-            dynamic_axes=dynamic_axes,
-        )
-
-        return graph._pretty_print_onnx(  # type: ignore[attr-defined]
-            params_dict,
-            opset_version,
-            False,
-            operator_export_type,
-            google_printer,
-            val_keep_init_as_ip,
-            custom_opsets,
-            val_add_node_names,
-        )
-
-
 @_deprecation.deprecated("2.5", "the future", "avoid using this function")
 def unconvertible_ops(
     model,
